1use ndarray::{ArrayD, IxDyn};
2
3#[derive(Clone, Debug)]
6pub struct Tensor {
7 pub data: ArrayD<f32>,
8}
9
10impl Tensor {
11 pub fn new(data: Vec<f32>, shape: &[usize]) -> Self {
14 Self {
15 data: ArrayD::from_shape_vec(IxDyn(shape), data)
16 .expect("Shape does not match data length"),
17 }
18 }
19
20 pub fn zeros(shape: &[usize]) -> Self {
22 Self {
23 data: ArrayD::zeros(IxDyn(shape)),
24 }
25 }
26
27 pub fn ones(shape: &[usize]) -> Self {
29 Self {
30 data: ArrayD::from_elem(IxDyn(shape), 1.0),
31 }
32 }
33
34 pub fn random(shape: &[usize]) -> Self {
36 use rand::distributions::Uniform;
37 use rand::Rng;
38
39 let size: usize = shape.iter().product();
40 let mut rng = rand::thread_rng();
41 let dist = Uniform::new(0.0, 1.0);
42 let data: Vec<f32> = (0..size).map(|_| rng.sample(&dist)).collect();
43 Self::new(data, shape)
44 }
45
46 pub fn from_scalar(value: f32) -> Self {
48 Self {
49 data: ArrayD::from_elem(IxDyn(&[]), value),
50 }
51 }
52
53 pub fn shape(&self) -> &[usize] {
55 self.data.shape()
56 }
57
58 pub fn sum(&self) -> f32 {
60 self.data.sum()
61 }
62
63 pub fn print(&self) {
65 println!("{:?}", self.data);
66 }
67
68 pub fn add(&self, other: &Self) -> Self {
70 Self {
71 data: &self.data + &other.data,
72 }
73 }
74
75 pub fn sub(&self, other: &Self) -> Self {
77 Self {
78 data: &self.data - &other.data,
79 }
80 }
81
82 pub fn mul(&self, other: &Self) -> Self {
84 Self {
85 data: &self.data * &other.data,
86 }
87 }
88
89 pub fn div(&self, other: &Self) -> Self {
91 Self {
92 data: &self.data / &other.data,
93 }
94 }
95
96 pub fn relu(&self) -> Self {
100 Self {
101 data: self.data.mapv(|x| if x > 0.0 { x } else { 0.0 }),
102 }
103 }
104
105 pub fn sigmoid(&self) -> Self {
107 Self {
108 data: self.data.mapv(|x| 1.0 / (1.0 + (-x).exp())),
109 }
110 }
111}