elara_math/tensor/mod.rs

use elara_log::prelude::*;
use ndarray::prelude::*;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

// use crate::randf;

use std::{
    cell::{Ref, RefCell, RefMut},
    collections::HashSet,
    fmt::Debug,
    hash::{Hash, Hasher},
    ops::{Add, AddAssign, Deref, DerefMut, Div, DivAssign, Mul, MulAssign, Sub, SubAssign},
    rc::Rc,
};

use uuid::Uuid;

/// A macro for counting the number of args
/// passed to it
#[macro_export]
macro_rules! count {
    [$($x:expr),*] => {
        vec![$($x),*].len()
    }
}

/// Macro for quickly creating tensors
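///
/// A minimal usage sketch (assuming the crate is named `elara_math`
/// and `Tensor` is importable as below; adjust the paths otherwise):
///
/// ```
/// use elara_math::{tensor, Tensor};
///
/// // nested brackets give a matrix tensor
/// let m = tensor![[1.0, 2.0], [3.0, 4.0]];
/// assert_eq!(m.shape(), (2, 2));
///
/// // a flat list gives an n x 1 column tensor
/// let v = tensor![1.0, 2.0, 3.0];
/// assert_eq!(v.shape(), (3, 1));
/// ```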
#[macro_export]
macro_rules! tensor {
    [$([$($x:expr),* $(,)*]),+ $(,)*] => {
        Tensor::new(ndarray::array!($([$($x,)*],)*))
    };
    [$($x:expr),*] => {
        Tensor::new(ndarray::array!($($x),*).into_shape(($crate::count![$($x),*], 1)).unwrap())
    };
}

/// Macro for quickly creating scalar tensors
#[macro_export]
macro_rules! scalar {
    ($x:expr) => {
        Tensor::from_f64($x)
    };
}

/// Backing data for `Tensor`
pub struct TensorData {
    pub data: Array2<f64>,
    pub grad: Array2<f64>,
    pub uuid: Uuid,
    backward: Option<fn(&TensorData)>,
    prev: Vec<Tensor>,
    op: Option<String>,
}

/// A PyTorch-like differentiable tensor type
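///
/// Gradients are tracked through operations; a minimal sketch
/// (assuming `Tensor` is importable from the crate root):
///
/// ```
/// use elara_math::Tensor;
///
/// let a = Tensor::from_f64(2.0);
/// let b = Tensor::from_f64(3.0);
/// let c = &a * &b;
/// c.backward();
/// // dc/da = b
/// assert_eq!(a.grad()[[0, 0]], 3.0);
/// ```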
#[derive(Clone)]
pub struct Tensor(Rc<RefCell<TensorData>>);

impl Hash for Tensor {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.borrow().uuid.hash(state);
    }
}

impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        self.borrow().uuid == other.borrow().uuid
    }
}

impl Eq for Tensor {}

impl Deref for Tensor {
    type Target = Rc<RefCell<TensorData>>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Tensor {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl TensorData {
    fn new(data: Array2<f64>) -> TensorData {
        let shape = data.raw_dim();
        TensorData {
            data,
            grad: Array2::zeros(shape),
            uuid: Uuid::new_v4(),
            backward: None,
            prev: Vec::new(),
            op: None,
        }
    }
}

impl Tensor {
    /// Create a new tensor from an `Array2`
    pub fn new(array: Array2<f64>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    /// Find the shape of a tensor
    pub fn shape(&self) -> (usize, usize) {
        self.borrow().data.dim()
    }

    /// Create a tensor filled with random values
    pub fn rand(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::random((shape[0], shape[1]), Uniform::new(0., 1.));
        Tensor::new(arr)
    }

    /// Create a tensor from an `f64`
    pub fn from_f64(val: f64) -> Tensor {
        Tensor::new(array![[val]])
    }

    /// Create a tensor of shape filled with ones
    pub fn ones(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::ones((shape[0], shape[1]));
        Tensor::new(arr)
    }

    /// Create a tensor of shape filled with zeros
    pub fn zeros(shape: [usize; 2]) -> Tensor {
        let arr: Array2<f64> = Array2::zeros((shape[0], shape[1]));
        Tensor::new(arr)
    }

    /// Update a tensor's value in place using its stored
    /// gradient and a learning rate; useful for machine
    /// learning applications
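    ///
    /// A hypothetical gradient-descent step (sketch):
    ///
    /// ```
    /// use elara_math::Tensor;
    ///
    /// let w = Tensor::from_f64(2.0);
    /// let loss = &w * &w;
    /// loss.backward(); // d(w^2)/dw = 2w = 4.0
    /// w.update(0.25);  // w <- 2.0 - 0.25 * 4.0
    /// assert_eq!(w.data()[[0, 0]], 1.0);
    /// ```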
    pub fn update(&self, lr: f64) {
        let mut data = self.inner_mut();
        let grad = data.grad.clone();
        data.data.scaled_add(-lr, &grad);
    }

    /// Create a tensor from a range
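    ///
    /// Sketch: six integers laid out row-major as a 2 x 3 tensor:
    ///
    /// ```
    /// use elara_math::Tensor;
    ///
    /// let t = Tensor::arange(0..6, [2, 3]);
    /// assert_eq!(t.shape(), (2, 3));
    /// assert_eq!(t.data()[[1, 2]], 5.0);
    /// ```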
    pub fn arange<I: Iterator<Item = i32>>(range: I, shape: [usize; 2]) -> Tensor {
        let arr = Array::from_iter(range)
            .mapv(|el| el as f64)
            .into_shape((shape[0], shape[1]))
            .unwrap();
        Tensor::new(arr)
    }

    /// Create a tensor containing a linearly-spaced
    /// interval
    pub fn linspace(start: f64, end: f64, num: usize) -> Tensor {
        let arr = Array::linspace(start, end, num);
        let arr_reshaped = arr.into_shape((num, 1)).unwrap();
        Tensor::new(arr_reshaped)
    }

    /// Change the shape of a tensor and return a new tensor
    pub fn reshape(&self, shape: [usize; 2]) -> Tensor {
        Tensor::new(self.data().clone().into_shape(shape).unwrap())
    }

    /// Get the number of elements in a tensor
    pub fn len(&self) -> usize {
        self.data().len()
    }

    /// Find the sum of a tensor
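    ///
    /// Sketch: summing sends a gradient of one to every input element:
    ///
    /// ```
    /// use elara_math::{tensor, Tensor};
    ///
    /// let t = tensor![1.0, 2.0, 3.0];
    /// let s = t.sum();
    /// assert_eq!(s.data()[[0, 0]], 6.0);
    /// s.backward();
    /// assert_eq!(t.grad()[[1, 0]], 1.0);
    /// ```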
    pub fn sum(&self) -> Tensor {
        let sum = self.data().sum();
        let out = Tensor::from_f64(sum);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("sum"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            value.prev[0].grad_mut().scaled_add(1.0, &value.grad);
        });
        out
    }

    /// Find the mean of a tensor
    pub fn mean(&self) -> Tensor {
        (1.0 / self.data().len() as f64) * self.sum()
    }

    /// Exponential function for tensors
    pub fn exp(&self) -> Tensor {
        let exp_array = self.borrow().data.mapv(|val| val.exp());
        let out = Tensor::new(exp_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("exp"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx exp(x) = exp(x), scaled by the upstream gradient (chain rule)
            let prev = value.prev[0].borrow().data.clone();
            value.prev[0]
                .grad_mut()
                .scaled_add(1.0, &(prev.mapv(|val| val.exp()) * &value.grad));
        });
        out
    }

    /// ReLU function for tensors
    pub fn relu(&self) -> Tensor {
        let relu_array = self.data().mapv(|val| val.max(0.0));
        let out = Tensor::new(relu_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("ReLU"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // pass the upstream gradient through wherever the input was positive
            let dv = value.prev[0]
                .data()
                .mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
            value.prev[0].grad_mut().scaled_add(1.0, &(dv * &value.grad));
        });
        out
    }

    /// Power function for tensors (not recommended as it breaks easily)
    pub fn pow(&self, power: f64) -> Tensor {
        let pow_array = self.data().mapv(|val| val.powf(power));
        let out = Tensor::new(pow_array);
        out.inner_mut().prev = vec![self.clone(), Tensor::from_f64(power)];
        out.inner_mut().op = Some(String::from("^"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx x^p = p * x^(p - 1), scaled by the upstream gradient
            let power = value.prev[1].data()[[0, 0]];
            let dv = value.prev[0].data().mapv(|val| power * val.powf(power - 1.0));
            value.prev[0].grad_mut().scaled_add(1.0, &(dv * &value.grad));
        });
        out
    }

    /// Sigmoid function for tensors
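    ///
    /// Sketch: the sigmoid of zero is one half:
    ///
    /// ```
    /// use elara_math::Tensor;
    ///
    /// let t = Tensor::zeros([1, 1]);
    /// assert_eq!(t.sigmoid().data()[[0, 0]], 0.5);
    /// ```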
    pub fn sigmoid(&self) -> Tensor {
        let sigmoid_array = self.borrow().data.mapv(|val| 1.0 / (1.0 + (-val).exp()));
        let out = Tensor::new(sigmoid_array);
        out.inner_mut().prev = vec![self.clone()];
        out.inner_mut().op = Some(String::from("sigmoid"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            // d/dx sigmoid(x) = exp(x) / (1 + exp(x))^2, scaled by the upstream gradient
            let prev = value.prev[0].borrow().data.clone();
            let dv = prev.mapv(|val| val.exp() / (1.0 + val.exp()).powf(2.0));
            value.prev[0].grad_mut().scaled_add(1.0, &(dv * &value.grad));
        });
        out
    }

    /// Tensor matrix multiplication
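    ///
    /// Sketch: a (2 x 3) matrix times a (3 x 1) column gives a (2 x 1) column:
    ///
    /// ```
    /// use elara_math::{tensor, Tensor};
    ///
    /// let a = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    /// let x = tensor![1.0, 1.0, 1.0];
    /// let y = a.matmul(&x);
    /// assert_eq!(y.shape(), (2, 1));
    /// assert_eq!(y.data()[[0, 0]], 6.0);
    /// ```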
    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        let a_shape = self.shape();
        let b_shape = rhs.shape();
        if a_shape.1 != b_shape.0 {
            error!("You are attempting to matrix-multiply two matrices of size {} x {} and {} x {}. These shapes are not compatible.", a_shape.0, a_shape.1, b_shape.0, b_shape.1);
        }
        let res: Array2<f64> = self.data().dot(rhs.data().deref());
        let out = Tensor::new(res);
        out.inner_mut().prev = vec![self.clone(), rhs.clone()];
        out.inner_mut().op = Some(String::from("matmul"));
        out.inner_mut().backward = Some(|value: &TensorData| {
            let da = value.grad.dot(&value.prev[1].data().t());
            let db = value.prev[0].data().t().dot(&value.grad);
            value.prev[0].grad_mut().scaled_add(1.0, &da);
            value.prev[1].grad_mut().scaled_add(1.0, &db);
        });
        out
    }

    /// Get the underlying `TensorData` of a tensor
    pub fn inner(&self) -> Ref<TensorData> {
        (*self.0).borrow()
    }

    /// Get the underlying `TensorData` of a tensor
    /// as mutable
    pub fn inner_mut(&self) -> RefMut<TensorData> {
        (*self.0).borrow_mut()
    }

    /// Get the underlying data `ndarray` array of a tensor
    pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.data)
    }

    /// Get the underlying data `ndarray` array of a tensor
    /// as mutable
    pub fn data_mut(&self) -> impl DerefMut<Target = Array2<f64>> + '_ {
        RefMut::map((*self.0).borrow_mut(), |mi| &mut mi.data)
    }

    /// Find the gradient of a tensor.
    /// Remember to call `backward()` first!
    pub fn grad(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.grad)
    }

    /// Get the gradient of a tensor as mutable.
    /// Remember to call `backward()` first!
    pub fn grad_mut(&self) -> impl DerefMut<Target = Array2<f64>> + '_ {
        RefMut::map((*self.0).borrow_mut(), |mi| &mut mi.grad)
    }

    /// Zero the gradient of a tensor
    pub fn zero_grad(&self) {
        self.grad_mut().fill(0.0);
    }

    /// Perform backpropagation on a tensor
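    ///
    /// Gradients accumulate across calls (each backward pass adds to
    /// `grad`), so zero them between passes; a sketch:
    ///
    /// ```
    /// use elara_math::Tensor;
    ///
    /// let x = Tensor::from_f64(3.0);
    /// let y = 2.0 * &x;
    /// y.backward();
    /// assert_eq!(x.grad()[[0, 0]], 2.0);
    /// y.backward(); // a second pass adds on top of the first
    /// assert_eq!(x.grad()[[0, 0]], 4.0);
    /// x.zero_grad();
    /// assert_eq!(x.grad()[[0, 0]], 0.0);
    /// ```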
    pub fn backward(&self) {
        let mut topo: Vec<Tensor> = vec![];
        let mut visited: HashSet<Tensor> = HashSet::new();
        self._build_topo(&mut topo, &mut visited);
        topo.reverse();

        self.grad_mut().fill(1.0);
        for v in topo {
            if let Some(backprop) = v.borrow().backward {
                backprop(&v.borrow());
            }
        }
    }

    fn _build_topo(&self, topo: &mut Vec<Tensor>, visited: &mut HashSet<Tensor>) {
        if visited.insert(self.clone()) {
            self.borrow().prev.iter().for_each(|child| {
                child._build_topo(topo, visited);
            });
            topo.push(self.clone());
        }
    }

    // Thanks to: https://stackoverflow.com/questions/76727378/how-to-implement-iter-for-a-type-that-wraps-an-ndarray
    /// Iterate over elements of a tensor
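    ///
    /// Sketch: each item is one row, yielded as a column tensor:
    ///
    /// ```
    /// use elara_math::{tensor, Tensor};
    ///
    /// let t = tensor![[1.0, 2.0], [3.0, 4.0]];
    /// let rows: Vec<Tensor> = t.iter().collect();
    /// assert_eq!(rows.len(), 2);
    /// assert_eq!(rows[0].shape(), (2, 1));
    /// ```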
    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        let data = self.data();
        (0..data.shape()[0]).map(move |i| {
            let el = data.index_axis(Axis(0), i);
            let reshaped_el = el.into_shape((el.shape()[0], 1)).unwrap().to_owned();
            Tensor::new(reshaped_el)
        })
    }
}

impl Iterator for Tensor {
    type Item = Tensor;
    fn next(&mut self) -> Option<Self::Item> {
        // NOTE: this impl keeps no cursor, so it only ever yields the
        // first row (or `None` for an empty tensor); prefer `iter()`
        self.iter().next()
    }
}

// TODO: better printing of tensors
impl Debug for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self.data().deref())
    }
}

// macro to automatically write impls for
// general tensor elementwise binary ops
// for basic math
macro_rules! impl_binary_op {
    [$trait:ident, $op_name:ident, $op:tt] => {
        impl $trait for Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: Tensor) -> Self::Output {
                &self $op &rhs
            }
        }

        impl $trait<f64> for &Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: f64) -> Self::Output {
                self $op &Tensor::from_f64(rhs)
            }
        }

        impl $trait<f64> for Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: f64) -> Self::Output {
                &self $op rhs
            }
        }

        impl $trait<&Tensor> for f64 {
            type Output = Tensor;

            fn $op_name(self, rhs: &Tensor) -> Self::Output {
                &Tensor::from_f64(self) $op rhs
            }
        }

        impl $trait<Tensor> for f64 {
            type Output = Tensor;

            fn $op_name(self, rhs: Tensor) -> Self::Output {
                self $op &rhs
            }
        }
    };

    [$trait:ident, $op_name:ident, $op:tt, $update_grad:expr] => {
        impl $trait for &Tensor {
            type Output = Tensor;

            fn $op_name(self, rhs: &Tensor) -> Self::Output {
                let out = Tensor::new(self.data().deref() $op rhs.data().deref());
                out.inner_mut().prev = vec![self.clone(), rhs.clone()];
                out.inner_mut().op = Some(stringify!($op_name).to_string());
                out.inner_mut().backward = Some(|value: &TensorData| {
                    let (dv1, dv2) = $update_grad(&value.grad, value.prev[0].data().deref(), value.prev[1].data().deref());

                    // reduce broadcast gradients back down to each operand's shape
                    let dv1 = match value.prev[0].grad().dim() {
                        (1, 1) => arr2(&[[dv1.sum()]]),
                        (1, n) => dv1.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
                        (n, 1) => dv1.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
                        (_, _) => dv1,
                    };
                    let dv2 = match value.prev[1].grad().dim() {
                        (1, 1) => arr2(&[[dv2.sum()]]),
                        (1, n) => dv2.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
                        (n, 1) => dv2.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
                        (_, _) => dv2,
                    };

                    value.prev[0].grad_mut().scaled_add(1.0, &dv1);
                    value.prev[1].grad_mut().scaled_add(1.0, &dv2);
                });
                out
            }
        }

        impl_binary_op![$trait, $op_name, $op];
    };
}

// similar macro to the one above, just for
// assignment ops for basic math,
// e.g. +=, -=, *=, /=

macro_rules! impl_assignment_op {
    [$trait:ident, $op_name:ident, $op:tt] => {
        impl $trait for Tensor {
            fn $op_name(&mut self, rhs: Tensor) {
                *self = self.clone() $op rhs;
            }
        }

        impl $trait<f64> for Tensor {
            fn $op_name(&mut self, rhs: f64) {
                *self = self.clone() $op rhs;
            }
        }
    };

    // [$trait:ident, $op_name:ident, $op:tt, $update_grad:expr] => {
    //     impl $trait for &Tensor {
    //         fn $op_name(&mut self, rhs: &Tensor) {
    //             self.inner_mut().prev = vec![self.clone(), rhs.clone()];
    //             self.inner_mut().op = Some(stringify!($op_name).to_string());
    //             self.inner_mut().backward = Some(|value: &TensorData| {
    //                 let (dv1, dv2) = $update_grad(&value.grad, value.prev[0].data().deref(), value.prev[1].data().deref());

    //                 let dv1 = match value.prev[0].grad().dim() {
    //                     (1, 1) => arr2(&[[dv1.sum()]]),
    //                     (1, n) => dv1.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
    //                     (n, 1) => dv1.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
    //                     (_, _) => dv1,
    //                 };
    //                 let dv2 = match value.prev[1].grad().dim() {
    //                     (1, 1) => arr2(&[[dv2.sum()]]),
    //                     (1, n) => dv2.sum_axis(Axis(0)).into_shape((1, n)).unwrap(),
    //                     (n, 1) => dv2.sum_axis(Axis(1)).into_shape((n, 1)).unwrap(),
    //                     (_, _) => dv2,
    //                 };

    //                 value.prev[0].grad_mut().scaled_add(1.0, &dv1);
    //                 value.prev[1].grad_mut().scaled_add(1.0, &dv2);
    //             });
    //         }
    //     }

    //     impl_assignment_op![$trait, $op_name, $op];
    // };
}

impl_binary_op![Add, add, +, |grad, _a, _b| { (grad * 1.0, grad * 1.0) }];
impl_binary_op![Sub, sub, -, |grad, _a, _b| { (grad * 1.0, grad * -1.0) }];
impl_binary_op![Mul, mul, *, |grad, a, b| { (grad * b, grad * a) }];
impl_binary_op![Div, div, /, |grad, a, b| { (grad * 1.0 / b, grad * -1.0 * a / (b * b)) }];

impl_assignment_op![AddAssign, add_assign, +];
impl_assignment_op![SubAssign, sub_assign, -];
impl_assignment_op![MulAssign, mul_assign, *];
impl_assignment_op![DivAssign, div_assign, /];
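
#[cfg(test)]
mod tests {
    // A small smoke test for the autograd plumbing above; a sketch, not an
    // exhaustive suite. `tensor!` is in textual scope here because it is
    // defined earlier in this file.
    use super::*;

    #[test]
    fn elementwise_mul_backward() {
        let a = tensor![1.0, 2.0];
        let b = tensor![3.0, 4.0];
        let c = (&a * &b).sum();
        c.backward();
        // d(sum(a * b))/da = b, and symmetrically for b
        assert_eq!(*a.grad(), ndarray::array![[3.0], [4.0]]);
        assert_eq!(*b.grad(), ndarray::array![[1.0], [2.0]]);
    }
}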