1use crate::vec_tools::ValidNumber;
2use std::fmt::Display;
3use std::ops::{Add, Div, Mul, Sub};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
6pub struct Tensor<T: ValidNumber<T>> {
7 pub shape: Vec<usize>,
8 pub data: Vec<T>,
9}
10
11impl<T: ValidNumber<T>> Tensor<T> {
13 pub fn new(shape: Vec<usize>) -> Tensor<T> {
14 let size = shape.iter().fold(1, |size, x| size * x);
15 Tensor {
16 shape: shape,
17 data: vec![T::from(0.0); size],
18 }
19 }
20
21 pub fn rank(&self) -> usize {
22 self.shape.len()
23 }
24
25 pub fn shape(&self) -> &Vec<usize> {
26 &self.shape
27 }
28
29 pub fn iter(&self) -> core::slice::Iter<'_, T> {
30 self.data.iter()
31 }
32
33 pub fn calculate_data_index(&self, loc: &[usize]) -> usize {
34 loc.into_iter()
35 .rev()
36 .enumerate()
37 .fold(0, |data_idx, (loc_idx, loc_val)| {
38 data_idx
39 + loc_val
40 * match loc_idx {
41 0 => 1,
42 _ => self
43 .shape
44 .clone()
45 .iter()
46 .rev()
47 .take(loc_idx)
48 .fold(1, |prod, x| prod * x),
49 }
50 })
51 }
52
53 pub fn get(&self, loc: &[usize]) -> Option<&T> {
54 if loc.len() != self.rank() {
55 return None;
56 }
57
58 let idx = self.calculate_data_index(loc);
59 self.data.get(idx)
60 }
61
62 pub fn get_mut(&mut self, loc: &[usize]) -> Option<&mut T> {
63 if loc.len() != self.rank() {
64 return None;
65 }
66
67 let idx = self.calculate_data_index(loc);
68 self.data.get_mut(idx)
69 }
70
71 pub fn elementwise_product(&self, other: &Tensor<T>) -> Result<Tensor<T>, ()> {
72 if self.shape != other.shape {
73 return Err(());
74 }
75
76 let mut out = Tensor::new(self.shape.clone());
77 for i in 0..self.data.len() {
78 out.data[i] = self.data[i] * other.data[i];
79 }
80
81 Ok(out)
82 }
83}
84
85impl<T: ValidNumber<T>> Add<Tensor<T>> for Tensor<T> {
86 type Output = Tensor<T>;
87
88 fn add(self, rhs: Tensor<T>) -> Self::Output {
89 if self.shape != rhs.shape {
90 panic!("Cannot add tensors of different shape: lhs {self:?}, rhs: {rhs:?}")
91 }
92
93 Tensor {
94 shape: self.shape,
95 data: self
96 .data
97 .iter()
98 .zip(rhs.data.iter())
99 .map(|(x, y)| *x + *y)
100 .collect(),
101 }
102 }
103}
104
105impl<T: ValidNumber<T>> Sub<Tensor<T>> for Tensor<T> {
106 type Output = Tensor<T>;
107
108 fn sub(self, rhs: Tensor<T>) -> Self::Output {
109 if self.shape != rhs.shape {
110 panic!("Cannot subtract tensors of different shape: lhs {self:?}, rhs: {rhs:?}")
111 }
112
113 Tensor {
114 shape: self.shape,
115 data: self
116 .data
117 .iter()
118 .zip(rhs.data.iter())
119 .map(|(x, y)| *x - *y)
120 .collect(),
121 }
122 }
123}
124
125impl<T: ValidNumber<T>> Mul<T> for Tensor<T> {
126 type Output = Tensor<T>;
127
128 fn mul(self, rhs: T) -> Self::Output {
129 Tensor {
130 shape: self.shape,
131 data: self.data.iter().map(|x| *x * rhs).collect(),
132 }
133 }
134}
135
136impl<T: ValidNumber<T>> Div<T> for Tensor<T> {
137 type Output = Tensor<T>;
138
139 fn div(self, rhs: T) -> Self::Output {
140 if rhs == T::from(0.0) {
141 panic!("Dividing by 0!")
142 }
143
144 Tensor {
145 shape: self.shape,
146 data: self.data.iter().map(|x| *x * rhs).collect(),
147 }
148 }
149}
150
151impl<T: ValidNumber<T>> From<Vec<T>> for Tensor<T> {
152 fn from(value: Vec<T>) -> Self {
153 Tensor {
154 shape: vec![value.len()],
155 data: value,
156 }
157 }
158}
159
160impl<T: ValidNumber<T>> From<Vec<Vec<T>>> for Tensor<T> {
161 fn from(value: Vec<Vec<T>>) -> Self {
162 let shape = vec![value.len(), value[0].len()];
163 let data: Vec<T> = value.into_iter().fold(vec![], |mut data, mut x| {
164 data.append(&mut x);
165 data
166 });
167
168 Tensor { shape, data }
169 }
170}
171
172impl<T: ValidNumber<T>> From<Vec<Vec<Vec<T>>>> for Tensor<T> {
173 fn from(value: Vec<Vec<Vec<T>>>) -> Self {
174 let shape = vec![value.len(), value[0].len(), value[0][1].len()];
175 let mut data: Vec<T> = vec![];
176
177 for layer in value {
178 for row in layer {
179 for item in row {
180 data.push(item)
181 }
182 }
183 }
184
185 Tensor { shape, data }
186 }
187}
188
189impl<T: ValidNumber<T>> Tensor<T> {
192 pub fn dot_product(&self, other: &Tensor<T>) -> Result<T, ()> {
193 if self.rank() != 1 || (self.shape() != other.shape()) {
194 return Err(());
195 }
196
197 Ok(self
198 .iter()
199 .zip(other.iter())
200 .fold(T::from(0.0), |res, (s, o)| res + *s * *o))
201 }
202}
203
204impl<T: ValidNumber<T>> Tensor<T> {
205 pub fn row_count(&self) -> usize {
206 self.shape[0]
207 }
208
209 pub fn col_count(&self) -> usize {
210 self.shape[1]
211 }
212
213 pub fn get_2dims(&self) -> (usize, usize) {
214 if self.rank() != 2 {
215 panic!("only defined for rank 2 tensors!")
216 }
217 (self.shape[0], self.shape[1])
218 }
219
220 pub fn as_rows(&self) -> Vec<Tensor<T>> {
221 let (_, cols) = self.get_2dims();
222
223 self.data
224 .chunks_exact(cols)
225 .map(|x| Tensor::from(x.to_vec()))
226 .collect()
227 }
228
229 pub fn as_columns(&self) -> Vec<Tensor<T>> {
230 let (rows, cols) = self.get_2dims();
231
232 (0..cols)
233 .map(|x| {
234 self.data
235 .iter()
236 .skip(x)
237 .step_by(cols)
238 .take(rows)
239 .cloned()
240 .collect()
241 })
242 .map(|x: Vec<T>| Tensor::from(x))
243 .collect()
244 }
245
246 pub fn matrix_multiply(&self, other: &Tensor<T>) -> Result<Tensor<T>, ()> {
247 if self.shape[1] != other.shape[0] {
248 return Err(());
249 }
250
251 let new_data: Vec<Vec<T>> = self
252 .as_rows()
253 .iter()
254 .map(|row| {
255 other
256 .as_columns()
257 .iter()
258 .map(|col| row.dot_product(col))
259 .collect::<Result<Vec<T>, ()>>()
260 })
261 .collect::<Result<Vec<Vec<T>>, ()>>()?;
262
263 Ok(Tensor::from(new_data))
264 }
265
266 pub fn transposed(&self) -> Tensor<T> {
267 Tensor {
268 shape: vec![self.shape[1], self.shape[0]],
269 data: self.data.clone(),
270 }
271 }
272
273 pub fn column(data: Vec<T>) -> Tensor<T> {
274 let data: Vec<Vec<T>> = data.into_iter().map(|x| vec![x]).collect();
275
276 Tensor::from(data)
277 }
278}
279
280impl<T: ValidNumber<T>> Display for Tensor<T> {
281 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
282 let (rows, cols) = self.get_2dims();
283
284 for row in 0..rows {
285 for col in 0..cols {
286 write!(f, "{:?} ", self.get(&[row, col]).unwrap())?;
287 }
288 writeln!(f)?;
289 }
290
291 Ok(())
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// 3×3 fixture holding 1.0 through 9.0, one inner vector per row.
    fn get_generic_tensor2d() -> Tensor<f64> {
        Tensor::<f64>::from(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ])
    }

    #[test]
    fn dot_product_test() {
        let x = Tensor::<f64>::from(vec![1.0, 2.0, 3.0]);
        let y = Tensor::<f64>::from(vec![1.0, 2.0, 3.0]);

        assert_eq!(x.dot_product(&y), Ok(14.0));
    }

    #[test]
    fn elementwise2d_test() {
        let x = Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let y = Tensor::<f64>::from(vec![vec![1.0, 1.0, 1.0], vec![0.0, 1.0, 1.0]]);

        let res = x
            .elementwise_product(&y)
            .expect("Incorrect Dimension Error");

        assert_eq!(
            Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0], vec![0.0, 5.0, 6.0]]),
            res
        );
    }

    #[test]
    fn incorrect_elementwise2d_test() {
        let x = Tensor::<f64>::from(vec![vec![2.0, 3.0]]);
        let y = Tensor::<f64>::from(vec![vec![1.0, 2.0, 3.0]]);

        assert!(x.elementwise_product(&y).is_err());
    }

    #[test]
    fn as_rows_test() {
        let rows = get_generic_tensor2d().as_rows();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], Tensor::<f64>::from(vec![1.0, 2.0, 3.0]));
        assert_eq!(rows[1], Tensor::<f64>::from(vec![4.0, 5.0, 6.0]));
        assert_eq!(rows[2], Tensor::<f64>::from(vec![7.0, 8.0, 9.0]));
    }

    #[test]
    fn tensor2d_matmul_1() {
        let x = get_generic_tensor2d();
        let y = get_generic_tensor2d();
        let res = x.matrix_multiply(&y).expect("3x3 by 3x3 shapes are compatible");

        assert_eq!(
            Tensor::<f64>::from(vec![
                vec![30.0, 36.0, 42.0],
                vec![66.0, 81.0, 96.0],
                vec![102.0, 126.0, 150.0]
            ]),
            res
        );
    }

    #[test]
    fn index_calc_test() {
        // Eight elements so every fixture's data length matches its shape.
        let data: Vec<f64> = (0..8u8).map(f64::from).collect();

        let tensor1d = Tensor {
            shape: vec![8],
            data: data.clone(),
        };

        let tensor2d = Tensor {
            shape: vec![2, 4],
            data: data.clone(),
        };

        let tensor3d = Tensor {
            shape: vec![2, 2, 2],
            data,
        };

        assert_eq!(tensor1d.calculate_data_index(&[2]), 2);
        assert_eq!(tensor2d.calculate_data_index(&[1, 1]), 5);
        assert_eq!(tensor3d.calculate_data_index(&[1, 1, 0]), 6);
    }

    #[test]
    fn subtracting_equal_tensors_gives_zeros() {
        let x = get_generic_tensor2d();
        let y = get_generic_tensor2d();
        let res = x - y;

        assert_eq!(res, Tensor::new(vec![3, 3]));
        assert_eq!(res.to_string(), "0.0 0.0 0.0 \n".repeat(3));
    }
}
406}