use elara_log::prelude::*;
use ndarray::prelude::*;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

use std::{
    cell::{Ref, RefCell, RefMut},
    collections::HashSet,
    fmt::Debug,
    hash::{Hash, Hasher},
    ops::{Add, AddAssign, Deref, DerefMut, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
    rc::Rc,
};

use uuid::Uuid;

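/// Counts the comma-separated expressions passed to it; used by [`tensor!`]
/// to compute the length of a one-dimensional tensor literal.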
#[macro_export]
macro_rules! count {
    [$($x:expr),*] => {
        vec![$($x),*].len()
    }
}

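/// Creates a [`Tensor`] from a 1-D or 2-D array literal. A 1-D literal such
/// as `tensor![1.0, 2.0, 3.0]` becomes a column vector of shape `(3, 1)`;
/// a 2-D literal such as `tensor![[1.0, 2.0], [3.0, 4.0]]` keeps its shape.
/// `Tensor` must be in scope at the call site.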
#[macro_export]
macro_rules! tensor {
    [$([$($x:expr),* $(,)*]),+ $(,)*] => {
        Tensor::new(ndarray::array!($([$($x,)*],)*))
    };
    [$($x:expr),*] => {
        Tensor::new(ndarray::array!($($x),*).into_shape(($crate::count![$($x),*], 1)).unwrap())
    };
}

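/// Creates a 1 x 1 [`Tensor`] from a single `f64` value; `scalar!(3.0)` is
/// shorthand for `Tensor::from_f64(3.0)`.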
#[macro_export]
macro_rules! scalar {
    ($x:expr) => {
        Tensor::from_f64($x)
    };
}

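/// The heap-allocated innards of a [`Tensor`]: the value itself, its
/// accumulated gradient, and the bookkeeping needed for backpropagation.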
pub struct TensorData {
    /// The tensor's values
    pub data: Array2<f64>,
    /// Accumulated gradient with respect to this tensor, same shape as `data`
    pub grad: Array2<f64>,
    /// Unique ID used for hashing and equality
    pub uuid: Uuid,
    /// Backward function that propagates `grad` to the parents in `prev`
    backward: Option<fn(&TensorData)>,
    /// Parent tensors this tensor was computed from
    prev: Vec<Tensor>,
    /// Name of the operation that produced this tensor, for debugging
    op: Option<String>,
}

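/// A 2-D tensor of `f64` values supporting reverse-mode automatic
/// differentiation. `Tensor` is a cheap-to-clone shared handle
/// (`Rc<RefCell<TensorData>>`): clones refer to the same data and gradient.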
#[derive(Clone)]
pub struct Tensor(Rc<RefCell<TensorData>>);

impl Hash for Tensor {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.borrow().uuid.hash(state);
    }
}

impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        self.borrow().uuid == other.borrow().uuid
    }
}

impl Eq for Tensor {}

impl Deref for Tensor {
    type Target = Rc<RefCell<TensorData>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Tensor {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl TensorData {
    fn new(data: Array2<f64>) -> TensorData {
        let shape = data.raw_dim();
        TensorData {
            data,
            grad: Array2::zeros(shape),
            uuid: Uuid::new_v4(),
            backward: None,
            prev: Vec::new(),
            op: None,
        }
    }
}

impl Tensor {
    /// Wraps an `ndarray` array into a new tensor with zeroed gradients.
    pub fn new(array: Array2<f64>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    /// Returns the tensor's shape as `(rows, columns)`.
    pub fn shape(&self) -> (usize, usize) {
        self.borrow().data.dim()
    }

    /// Creates a tensor of the given shape with entries drawn uniformly from [0, 1).
    pub fn rand(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::random((shape[0], shape[1]), Uniform::new(0., 1.));
        Tensor::new(arr)
    }

    /// Creates a 1 x 1 tensor holding a single value.
    pub fn from_f64(val: f64) -> Tensor {
        Tensor::new(array![[val]])
    }

    /// Creates a tensor of the given shape filled with ones.
    pub fn ones(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::ones((shape[0], shape[1]));
        Tensor::new(arr)
    }

    /// Creates a tensor of the given shape filled with zeros.
    pub fn zeros(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::zeros((shape[0], shape[1]));
        Tensor::new(arr)
    }

    /// Performs one step of gradient descent, updating the tensor's data
    /// in-place: `data -= lr * grad`.
    pub fn update(&self, lr: f64) {
        let mut data = self.inner_mut();
        let grad = data.grad.clone();
        data.data.scaled_add(-lr, &grad);
    }

    /// Creates a tensor of the given shape from an integer range. Panics if
    /// the range's length does not match the shape.
    pub fn arange<I: Iterator<Item = i32>>(range: I, shape: [usize; 2]) -> Tensor {
        let arr = Array::from_iter(range)
            .mapv(|el| el as f64)
            .into_shape((shape[0], shape[1]))
            .unwrap();
        Tensor::new(arr)
    }

    /// Creates a column vector of `num` evenly spaced values from `start` to `end`.
    pub fn linspace(start: f64, end: f64, num: usize) -> Tensor {
        let arr = Array::linspace(start, end, num);
        let arr_reshaped = arr.into_shape((num, 1)).unwrap();
        Tensor::new(arr_reshaped)
    }

    /// Returns a new tensor with the same data in a different shape. This does
    /// not mutate `self`, so it only needs `&self`.
    pub fn reshape(&self, shape: [usize; 2]) -> Tensor {
        Tensor::new(self.data().clone().into_shape(shape).unwrap())
    }

    /// Returns the number of elements in the tensor.
    pub fn len(&self) -> usize {
        self.data().len()
    }

    /// Sums all elements into a 1 x 1 tensor, recording the operation so that
    /// the gradient broadcasts back to every input element.
    pub fn sum(&self) -> Tensor {
        let sum = self.data().sum();
        let out = Tensor::from_f64(sum);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("sum"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            value.prev[0].grad_mut().scaled_add(1.0, &value.grad);
        });
        out
    }

    /// Returns the mean of all elements as a 1 x 1 tensor.
    pub fn mean(&self) -> Tensor {
        (1.0 / self.data().len() as f64) * self.sum()
    }

    /// Applies `e^x` elementwise.
    pub fn exp(&self) -> Tensor {
        let exp_array = self.borrow().data.mapv(|val| val.exp());
        let out = Tensor::new(exp_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("exp"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx e^x = e^x, scaled by the upstream gradient
            let prev = value.prev[0].borrow().data.clone();
            value.prev[0]
                .grad_mut()
                .scaled_add(1.0, &(prev.mapv(|val| val.exp()) * &value.grad));
        });
        out
    }

    /// Applies the rectified linear unit `max(x, 0)` elementwise.
    pub fn relu(&self) -> Tensor {
        let relu_array = self.data().mapv(|val| val.max(0.0));
        let out = Tensor::new(relu_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("ReLU"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // The derivative is 1 where the input was positive and 0 elsewhere,
            // scaled by the upstream gradient
            let dv = value.prev[0]
                .data()
                .mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
            value.prev[0].grad_mut().scaled_add(1.0, &(dv * &value.grad));
        });
        out
    }

    /// Raises each element to the power `power`. The exponent is stored as a
    /// second (constant) parent so the backward pass can read it back.
    pub fn pow(&self, power: f64) -> Tensor {
        let pow_array = self.data().mapv(|val| val.powf(power));
        let out = Tensor::new(pow_array);
        out.inner_mut().prev = vec![self.clone(), Tensor::from_f64(power)];
        out.inner_mut().op = Some(String::from("^"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx x^p = p * x^(p - 1), scaled by the upstream gradient
            let base_vec = value.prev[0]
                .data()
                .mapv(|val| val.powf(value.prev[1].data()[[0, 0]] - 1.0));
            value.prev[0].grad_mut().scaled_add(
                1.0,
                &(value.prev[1].data().deref() * base_vec * value.grad.clone()),
            );
        });
        out
    }

    /// Applies the logistic sigmoid `1 / (1 + e^-x)` elementwise.
    pub fn sigmoid(&self) -> Tensor {
        let sigmoid_array = self.borrow().data.mapv(|val| 1.0 / (1.0 + (-val).exp()));
        let out = Tensor::new(sigmoid_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("sigmoid"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx sigmoid(x) = e^x / (1 + e^x)^2, scaled by the upstream gradient
            let prev = value.prev[0].borrow().data.clone();
            let dv = prev.mapv(|val| val.exp() / (1.0 + val.exp()).powf(2.0));
            value.prev[0].grad_mut().scaled_add(1.0, &(dv * &value.grad));
        });
        out
    }

    /// Matrix-multiplies `self` (n x m) by `rhs` (m x p), logging an error if
    /// the inner dimensions do not match.
    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        let a_shape = self.shape();
        let b_shape = rhs.shape();
        if a_shape.1 != b_shape.0 {
            error!("You are attempting to matrix-multiply two matrices of size {} x {} and {} x {}. These shapes are not compatible.", a_shape.0, a_shape.1, b_shape.0, b_shape.1);
        }
        let res: Array2<f64> = self.data().dot(rhs.data().deref());
        let out = Tensor::new(res);
        out.inner_mut().prev = vec![self.clone(), rhs.clone()];
        out.inner_mut().op = Some(String::from("matmul"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // For C = A . B: dA = dC . B^T and dB = A^T . dC
            let da = value.grad.dot(&value.prev[1].data().t());
            let db = value.prev[0].data().t().dot(&value.grad);
            value.prev[0].grad_mut().scaled_add(1.0, &da);
            value.prev[1].grad_mut().scaled_add(1.0, &db);
        });
        out
    }

    /// Immutably borrows the underlying [`TensorData`].
    pub fn inner(&self) -> Ref<TensorData> {
        (*self.0).borrow()
    }

    /// Mutably borrows the underlying [`TensorData`].
    pub fn inner_mut(&self) -> RefMut<TensorData> {
        (*self.0).borrow_mut()
    }

    /// Immutably borrows the tensor's data array.
    pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.data)
    }

    /// Mutably borrows the tensor's data array.
    pub fn data_mut(&self) -> impl DerefMut<Target = Array2<f64>> + '_ {
        RefMut::map((*self.0).borrow_mut(), |mi| &mut mi.data)
    }

    /// Immutably borrows the tensor's gradient array.
    pub fn grad(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.grad)
    }

    /// Mutably borrows the tensor's gradient array.
    pub fn grad_mut(&self) -> impl DerefMut<Target = Array2<f64>> + '_ {
        RefMut::map((*self.0).borrow_mut(), |mi| &mut mi.grad)
    }

    /// Resets the gradient to zero.
    pub fn zero_grad(&self) {
        self.grad_mut().fill(0.0);
    }

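    /// Runs backpropagation from this tensor: builds a topological ordering
    /// of the computation graph, seeds this tensor's gradient with ones, and
    /// applies each stored backward function in reverse topological order.
    /// Gradients accumulate across calls, so use [`Tensor::zero_grad`] on
    /// parameters between passes.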
    pub fn backward(&self) {
        let mut topo: Vec<Tensor> = vec![];
        let mut visited: HashSet<Tensor> = HashSet::new();
        self._build_topo(&mut topo, &mut visited);
        topo.reverse();

        self.grad_mut().fill(1.0);
        for v in topo {
            if let Some(backprop) = v.borrow().backward {
                backprop(&v.borrow());
            }
        }
    }

    /// Post-order depth-first traversal of the computation graph: each tensor
    /// is pushed after all of its inputs, so reversing the list in `backward`
    /// visits outputs before the tensors they depend on.
    fn _build_topo(&self, topo: &mut Vec<Tensor>, visited: &mut HashSet<Tensor>) {
        if visited.insert(self.clone()) {
            self.borrow().prev.iter().for_each(|child| {
                child._build_topo(topo, visited);
            });
            topo.push(self.clone());
        }
    }

    /// Iterates over the rows of the tensor, yielding each row as a new
    /// column-vector tensor detached from the computation graph.
    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        let data = self.data();
        (0..data.shape()[0]).map(move |i| {
            let row = data.index_axis(Axis(0), i);
            let reshaped = row.into_shape((row.shape()[0], 1)).unwrap().to_owned();
            Tensor::new(reshaped)
        })
    }
}

// Note: this impl is stateless, so it yields the first row indefinitely
// (and `None` for an empty tensor); prefer `Tensor::iter` for row-by-row
// traversal.
impl Iterator for Tensor {
    type Item = Tensor;
    fn next(&mut self) -> Option<Self::Item> {
        self.iter().next()
    }
}

impl Debug for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.data().deref())
    }
}

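// Implements an elementwise binary operator for every combination of `Tensor`,
// `&Tensor`, and `f64` operands. The four-argument form also installs the
// backward pass: `$update_grad` maps (upstream gradient, lhs data, rhs data)
// to the two local gradients, which are then reduced back to each operand's
// own shape to undo any broadcasting before being accumulated.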
macro_rules! impl_binary_op {
    [$trait:ident, $op_name:ident, $op:tt] => {
        impl $trait for Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: Tensor) -> Self::Output {
                &self $op &rhs
            }
        }

        impl $trait<f64> for &Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: f64) -> Self::Output {
                self $op &Tensor::from_f64(rhs)
            }
        }

        impl $trait<f64> for Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: f64) -> Self::Output {
                &self $op rhs
            }
        }

        impl $trait<&Tensor> for f64 {
            type Output = Tensor;

            fn $op_name(self, rhs: &Tensor) -> Self::Output {
                &Tensor::from_f64(self) $op rhs
            }
        }

        impl $trait<Tensor> for f64 {
            type Output = Tensor;

            fn $op_name(self, rhs: Tensor) -> Self::Output {
                self $op &rhs
            }
        }
    };

    [$trait:ident, $op_name:ident, $op:tt, $update_grad:expr] => {
        impl $trait for &Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: &Tensor) -> Self::Output {
                let out = Tensor::new(self.data().deref() $op rhs.data().deref());
                out.inner_mut().prev = vec![self.clone(), rhs.clone()];
                out.inner_mut().op = Some(stringify!($op_name).to_string());
                out.inner_mut().backward = Some(|value: &TensorData| {
                    let (dv1, dv2) = $update_grad(&value.grad, value.prev[0].data().deref(), value.prev[1].data().deref());

                    // If an operand was broadcast up to a larger shape, sum its
                    // gradient back down to the operand's own shape
                    let dv1 = match value.prev[0].grad().dim() {
                        (1, 1) => arr2(&[[dv1.sum()]]),
                        (1, n) => dv1.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
                        (n, 1) => dv1.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
                        (_, _) => dv1,
                    };
                    let dv2 = match value.prev[1].grad().dim() {
                        (1, 1) => arr2(&[[dv2.sum()]]),
                        (1, n) => dv2.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
                        (n, 1) => dv2.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
                        (_, _) => dv2,
                    };

                    value.prev[0].grad_mut().scaled_add(1.0, &dv1);
                    value.prev[1].grad_mut().scaled_add(1.0, &dv2);
                });
                out
            }
        }

        impl_binary_op![$trait, $op_name, $op];
    };
}

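// Implements a compound-assignment operator (`+=`, `-=`, `*=`, `/=`) for
// `Tensor` by delegating to the corresponding binary operator, so the result
// stays connected to the computation graph.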
macro_rules! impl_assignment_op {
    [$trait:ident, $op_name:ident, $op:tt] => {
        impl $trait for Tensor {
            fn $op_name(&mut self, rhs: Tensor) {
                *self = self.clone() $op rhs;
            }
        }

        impl $trait<f64> for Tensor {
            fn $op_name(&mut self, rhs: f64) {
                *self = self.clone() $op rhs;
            }
        }
    };
}

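// Local derivatives for each elementwise op, written as (d/d_lhs, d/d_rhs)
// scaled by the upstream gradient `grad`:
// d(a + b) = (g, g); d(a - b) = (g, -g); d(a * b) = (g * b, g * a);
// d(a / b) = (g / b, -g * a / b^2).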
impl_binary_op![Add, add, +, |grad, _a, _b| { (grad * 1.0, grad * 1.0) }];
impl_binary_op![Sub, sub, -, |grad, _a, _b| { (grad * 1.0, grad * -1.0) }];
impl_binary_op![Mul, mul, *, |grad, a, b| { (grad * b, grad * a) }];
impl_binary_op![Div, div, /, |grad, a, b| { (grad * 1.0 / b, grad * -1.0 * a / (b * b)) }];

impl_assignment_op![AddAssign, add_assign, +];
impl_assignment_op![SubAssign, sub_assign, -];
impl_assignment_op![MulAssign, mul_assign, *];
impl_assignment_op![DivAssign, div_assign, /];
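
// A minimal sanity check of the autograd plumbing, added as a sketch: it
// assumes `tensor!` is visible here by textual scope and that the `ndarray`
// version in use (>= 0.15) supports co-broadcasting in arithmetic ops, which
// the backward passes rely on.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backward_through_square_and_sum() {
        // y = sum(x * x), so dy/dx = 2x elementwise.
        let x = tensor![1.0, 2.0, 3.0];
        let y = (&x * &x).sum();
        y.backward();
        assert_eq!(y.data()[[0, 0]], 14.0);
        for (g, v) in x.grad().iter().zip([1.0, 2.0, 3.0]) {
            assert_eq!(*g, 2.0 * v);
        }
    }

    #[test]
    fn matmul_gradients() {
        // c = sum(a . b) with all-ones inputs: each entry of `a` touches two
        // entries of the 2 x 2 product, so every gradient entry is 2.0, and
        // likewise for `b`.
        let a = Tensor::ones([2, 3]);
        let b = Tensor::ones([3, 2]);
        let c = a.matmul(&b).sum();
        c.backward();
        assert!(a.grad().iter().all(|&g| g == 2.0));
        assert!(b.grad().iter().all(|&g| g == 2.0));
    }
}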