linalg_rs/matrix/
mod.rs

//! Making work with matrices in Rust easier!
//!
//! For now only basic operations are supported, but more are to be added.
//!
//! This file is kept under 1500 lines and acts as the core file.
6
7mod helper;
8mod optim;
9
10use helper::*;
11
12use serde::{Deserialize, Serialize};
13use std::{
14    error::Error,
15    fmt::{Debug, Display},
16    fs,
17    marker::PhantomData,
18    ops::Div,
19    str::FromStr,
20};
21
22use itertools::{iproduct, Itertools};
23use num_traits::{pow, real::Real, sign::abs, Float};
24use rand::Rng;
25use rayon::prelude::*;
26use std::iter::Sum;
27
28use crate::{at, LinAlgFloats, MatrixElement, MatrixError, SparseMatrix};
29
/// Dimensions of a matrix as a `(rows, cols)` tuple of `usize`,
/// the same order returned by `Matrix::shape`.
pub type Shape = (usize, usize);
33
#[derive(Clone, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
/// General dense matrix.
///
/// Elements are stored in one flat `Vec<T>` in row-major order: flat index
/// `idx` corresponds to `(idx / ncols, idx % ncols)` (see `one_to_2d_idx`).
///
/// NOTE: the derived `PartialOrd` and `Serialize` follow field declaration
/// order (`data`, `nrows`, `ncols`), so reordering the fields changes behavior.
pub struct Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
    /// Vector containing all data, stored row by row
    data: Vec<T>,
    /// Number of rows.
    ///
    /// This field is public. Code that changes it directly must keep
    /// `nrows * ncols == data.len()`, which `Matrix::new` enforces.
    pub nrows: usize,
    /// Number of columns (the same invariant as `nrows` applies)
    pub ncols: usize,
    /// Zero-sized marker that ties the struct to the lifetime `'a` used in
    /// the parallel-iterator bounds; it holds no data.
    _lifetime: PhantomData<&'a T>,
}
53
// Lets a `Matrix` be used where a `std::error::Error` is expected
// (e.g. boxed as `Box<dyn Error>`). `Error` requires `Debug` (derived on the
// struct) and `Display` (implemented in this file). All provided methods
// keep their defaults.
impl<'a, T> Error for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}

// SAFETY: `Matrix` holds a `Vec<T>` and a `PhantomData<&'a T>`. Those are
// `Send` whenever `T: Send + Sync`, and the where clause already implies both:
// - rayon implements `IntoParallelIterator for Vec<T>` only for `T: Send`;
// - `Vec<&'a T>: IntoParallelRefIterator<'a>` needs
//   `&'a Vec<&'a T>: IntoParallelIterator`, which requires `&'a T: Sync`,
//   i.e. `T: Sync`.
// So this impl grants nothing beyond what the auto-trait rules would.
// NOTE(review): this argument relies on rayon's impl bounds. Re-check it if
// the where clause or the rayon version changes.
unsafe impl<'a, T> Send for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}

// SAFETY: `Vec<T>` and `PhantomData<&'a T>` are `Sync` whenever `T: Sync`.
// As argued for `Send` above, the `Vec<&'a T>: IntoParallelRefIterator<'a>`
// bound already implies `T: Sync`.
unsafe impl<'a, T> Sync for Matrix<'a, T>
where
    T: MatrixElement,
    <T as FromStr>::Err: Error + 'static,
    Vec<T>: IntoParallelIterator,
    Vec<&'a T>: IntoParallelRefIterator<'a>,
{
}
80
81impl<'a, T> FromStr for Matrix<'a, T>
82where
83    T: MatrixElement,
84    <T as FromStr>::Err: Error + 'static,
85    Vec<T>: IntoParallelIterator,
86    Vec<&'a T>: IntoParallelRefIterator<'a>,
87{
88    type Err = anyhow::Error;
89    fn from_str(s: &str) -> Result<Self, Self::Err> {
90        // Parse the input string and construct the matrix dynamically
91        let v: Vec<T> = s
92            .trim()
93            .lines()
94            .map(|l| {
95                l.split_whitespace()
96                    .map(|num| num.parse::<T>().unwrap())
97                    .collect::<Vec<T>>()
98            })
99            .collect::<Vec<Vec<T>>>()
100            .into_iter()
101            .flatten()
102            .collect();
103
104        let rows = s.trim().lines().count();
105        let cols = s.trim().lines().nth(0).unwrap().split_whitespace().count();
106
107        Ok(Self::new(v, (rows, cols)).unwrap())
108    }
109}
110
111impl<'a, T> Display for Matrix<'a, T>
112where
113    T: MatrixElement,
114    <T as FromStr>::Err: Error + 'static,
115    Vec<T>: IntoParallelIterator,
116    Vec<&'a T>: IntoParallelRefIterator<'a>,
117{
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(f, "[");
120
121        // Large matrices
122        if self.nrows > 10 || self.ncols > 10 {
123            write!(f, "...");
124        }
125
126        for i in 0..self.nrows {
127            for j in 0..self.ncols {
128                if i == 0 {
129                    write!(f, "{:.4} ", self.get(i, j).unwrap());
130                } else {
131                    write!(f, " {:.4}", self.get(i, j).unwrap());
132                }
133            }
134            // Print ] on same line if youre at the end
135            if i == self.nrows - 1 {
136                break;
137            }
138            writeln!(f);
139        }
140        writeln!(f, "], dtype={}", std::any::type_name::<T>())
141    }
142}
143
144impl<'a, T> Default for Matrix<'a, T>
145where
146    T: MatrixElement,
147    <T as FromStr>::Err: Error + 'static,
148    Vec<T>: IntoParallelIterator,
149    Vec<&'a T>: IntoParallelRefIterator<'a>,
150{
151    /// Represents a default identity matrix
152    ///
153    /// # Examples
154    ///
155    /// ```
156    /// use sukker::Matrix;
157    ///
158    /// let matrix: Matrix<f32> = Matrix::default();
159    ///
160    /// assert_eq!(matrix.size(), 9);
161    /// assert_eq!(matrix.shape(), (3,3));
162    /// ```
163    fn default() -> Self {
164        Self::eye(3)
165    }
166}
167
168/// Printer functions for the matrix
169impl<'a, T> Matrix<'a, T>
170where
171    T: MatrixElement,
172    <T as FromStr>::Err: Error + 'static,
173    Vec<T>: IntoParallelIterator,
174    Vec<&'a T>: IntoParallelRefIterator<'a>,
175{
176    /// Prints out the matrix with however many decimals you want
177    ///
178    /// # Examples
179    ///
180    /// ```
181    /// use sukker::Matrix;
182    ///
183    /// let matrix: Matrix<i32> = Matrix::eye(2);
184    /// matrix.print(4);
185    ///
186    /// ```
187    pub fn print(&self, decimals: usize) {
188        print!("[");
189
190        // Large matrices
191        if self.nrows > 10 || self.ncols > 10 {
192            print!("...");
193        }
194
195        for i in 0..self.nrows {
196            for j in 0..self.ncols {
197                if i == 0 {
198                    print!(
199                        "{val:.dec$} ",
200                        dec = decimals,
201                        val = self.get(i, j).unwrap()
202                    );
203                } else {
204                    print!(
205                        " {val:.dec$}",
206                        dec = decimals,
207                        val = self.get(i, j).unwrap()
208                    );
209                }
210            }
211            // Print ] on same line if youre at the end
212            if i == self.nrows - 1 {
213                break;
214            }
215            println!();
216        }
217        println!("], dtype={}", std::any::type_name::<T>());
218    }
219
220    /// Calculates sparsity of a given Matrix
221    ///
222    /// Examples:
223    ///
224    /// ```
225    /// use sukker::Matrix;
226    ///
227    /// let mat: Matrix<f32> = Matrix::eye(2);
228    ///
229    /// assert_eq!(mat.sparsity(), 0.5);
230    /// ```
231    #[inline(always)]
232    pub fn sparsity(&'a self) -> f64 {
233        self.count_where(|&e| e == T::zero()) as f64 / self.size() as f64
234    }
235
236    /// Returns the shape of a matrix represented as  
237    /// (usize, usize)
238    ///
239    /// Examples:
240    ///
241    /// ```
242    /// use sukker::Matrix;
243    ///
244    /// let mat: Matrix<f32> = Matrix::eye(4);
245    ///
246    /// assert_eq!(mat.shape(), (4,4));
247    /// ```
248    pub fn shape(&self) -> Shape {
249        (self.nrows, self.ncols)
250    }
251}
252
253/// Implementations of all creatins of matrices
254impl<'a, T> Matrix<'a, T>
255where
256    T: MatrixElement,
257    <T as FromStr>::Err: Error,
258    Vec<T>: IntoParallelIterator,
259    Vec<&'a T>: IntoParallelRefIterator<'a>,
260{
261    /// Creates a new matrix from a vector and the shape you want.
262    /// Will default init if it does not work
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// use sukker::Matrix;
268    ///
269    /// let matrix = Matrix::new(vec![1.0,2.0,3.0,4.0], (2,2)).unwrap();
270    ///
271    /// assert_eq!(matrix.size(), 4);
272    /// assert_eq!(matrix.shape(), (2,2));
273    /// ```
274    pub fn new(data: Vec<T>, shape: Shape) -> Result<Self, MatrixError> {
275        if shape.0 * shape.1 != data.len() {
276            return Err(MatrixError::MatrixCreationError.into());
277        }
278
279        Ok(Self {
280            data,
281            nrows: shape.0,
282            ncols: shape.1,
283            _lifetime: PhantomData::default(),
284        })
285    }
286
287    /// Initializes a matrix with the same value
288    /// given from parameter 'value'
289    ///
290    /// # Examples
291    ///
292    /// ```
293    /// use sukker::Matrix;
294    ///
295    /// let matrix = Matrix::init(4f32, (1,2));
296    ///
297    /// assert_eq!(matrix.get_vec(), vec![4f32,4f32]);
298    /// assert_eq!(matrix.shape(), (1,2));
299    /// ```
300    pub fn init(value: T, shape: Shape) -> Self {
301        Self::from_shape(value, shape)
302    }
303
304    /// Returns an eye matrix which for now is the same as the
305    /// identity matrix
306    ///
307    /// # Examples
308    ///
309    /// ```
310    /// use sukker::Matrix;
311    ///
312    /// let matrix: Matrix<f32> = Matrix::eye(2);
313    ///
314    /// assert_eq!(matrix.get_vec(), vec![1f32, 0f32, 0f32, 1f32]);
315    /// assert_eq!(matrix.shape(), (2,2));
316    /// ```
317    pub fn eye(size: usize) -> Self {
318        let mut data: Vec<T> = vec![T::zero(); size * size];
319
320        (0..size).for_each(|i| data[i * size + i] = T::one());
321
322        // Safe to do since the library is setting the size
323        Self::new(data, (size, size)).unwrap()
324    }
325
326    /// Produce an identity matrix in the same shape as
327    /// an already existent matrix
328    ///
329    /// Examples
330    ///
331    /// ```
332    ///
333    ///
334    /// ```
335    pub fn eye_like(matrix: &Self) -> Self {
336        Self::eye(matrix.nrows)
337    }
338
339    /// Identity is same as eye, just for nerds
340    ///
341    /// # Examples
342    ///
343    /// ```
344    /// use sukker::Matrix;
345    ///
346    /// let matrix: Matrix<i64> = Matrix::identity(2);
347    ///
348    /// assert_eq!(matrix.get_vec(), vec![1i64, 0i64, 0i64, 1i64]);
349    /// assert_eq!(matrix.shape(), (2,2));
350    /// ```
351    pub fn identity(size: usize) -> Self {
352        Self::eye(size)
353    }
354
355    /// Tries to create a matrix from a slize and shape
356    ///
357    /// # Examples
358    ///
359    /// ```
360    /// use sukker::Matrix;
361    ///
362    /// let s = vec![1f32, 2f32, 3f32, 4f32];
363    /// let matrix = Matrix::from_slice(&s, (4,1)).unwrap();
364    ///
365    /// assert_eq!(matrix.shape(), (4,1));
366    /// ```
367    pub fn from_slice(arr: &[T], shape: Shape) -> Result<Self, MatrixError> {
368        if shape.0 * shape.1 != arr.len() {
369            return Err(MatrixError::MatrixCreationError.into());
370        }
371
372        Ok(Self::new(arr.to_owned(), shape).unwrap())
373    }
374
375    /// Creates a matrix where all values are 0.
376    /// All sizes are based on a shape
377    ///
378    /// # Examples
379    ///
380    /// ```
381    /// use sukker::Matrix;
382    ///
383    /// let matrix: Matrix<f64> = Matrix::zeros((4,1));
384    ///
385    /// assert_eq!(matrix.shape(), (4,1));
386    /// assert_eq!(matrix.get(0,0).unwrap(), 0f64);
387    /// assert_eq!(matrix.size(), 4);
388    /// ```
389    pub fn zeros(shape: Shape) -> Self {
390        Self::from_shape(T::zero(), shape)
391    }
392
393    /// Creates a matrix where all values are 1.
394    /// All sizes are based on a shape
395    ///
396    /// # Examples
397    ///
398    /// ```
399    /// use sukker::Matrix;
400    ///
401    /// let matrix: Matrix<f64> = Matrix::ones((4,1));
402    ///
403    /// assert_eq!(matrix.shape(), (4,1));
404    /// assert_eq!(matrix.get(0,0).unwrap(), 1f64);
405    /// assert_eq!(matrix.size(), 4);
406    /// ```
407    pub fn ones(shape: Shape) -> Self {
408        Self::from_shape(T::one(), shape)
409    }
410
411    /// Creates a matrix where all values are 0.
412    /// All sizes are based on an already existent matrix
413    ///
414    /// # Examples
415    ///
416    /// ```
417    /// use sukker::Matrix;
418    ///
419    /// let matrix1: Matrix<i8> = Matrix::default();
420    /// let matrix2 = Matrix::zeros_like(&matrix1);
421    ///
422    /// assert_eq!(matrix2.shape(), matrix1.shape());
423    /// assert_eq!(matrix2.get(0,0).unwrap(), 0i8);
424    /// ```
425    pub fn zeros_like(other: &Self) -> Self {
426        Self::from_shape(T::zero(), other.shape())
427    }
428
429    /// Creates a matrix where all values are 1.
430    /// All sizes are based on an already existent matrix
431    ///
432    /// # Examples
433    ///
434    /// ```
435    /// use sukker::Matrix;
436    ///
437    /// let matrix1: Matrix<i64> = Matrix::default();
438    /// let matrix2 = Matrix::ones_like(&matrix1);
439    ///
440    /// assert_eq!(matrix2.shape(), matrix1.shape());
441    /// assert_eq!(1i64, matrix2.get(0,0).unwrap());
442    /// ```
443    pub fn ones_like(other: &Self) -> Self {
444        Self::from_shape(T::one(), other.shape())
445    }
446
447    /// Creates a matrix where all values are random between 0 and 1.
448    /// All sizes are based on an already existent matrix
449    ///
450    /// # Examples
451    ///
452    /// ```
453    /// use sukker::Matrix;
454    ///
455    /// let matrix1: Matrix<f32> = Matrix::default();
456    /// let matrix2 = Matrix::random_like(&matrix1);
457    ///
458    /// assert_eq!(matrix1.shape(), matrix2.shape());
459    ///
460    ///
461    /// ```
462    pub fn random_like(matrix: &Self) -> Self {
463        Self::randomize_range(T::zero(), T::one(), matrix.shape())
464    }
465
466    /// Creates a matrix where all values are random between start..=end.
467    /// Shape in new array is given through parameter 'shape'
468    ///
469    /// # Examples
470    ///
471    /// ```
472    /// use sukker::Matrix;
473    ///
474    /// let matrix = Matrix::randomize_range(1f32, 2f32, (2,3));
475    /// let elem = matrix.get(1,1).unwrap();
476    ///
477    /// assert_eq!(matrix.shape(), (2,3));
478    /// //assert!(elem >= 1f32 && 2f32 <= elem);
479    /// ```
480    pub fn randomize_range(start: T, end: T, shape: Shape) -> Self {
481        let mut rng = rand::thread_rng();
482
483        let (rows, cols) = shape;
484
485        let len: usize = rows * cols;
486
487        let data: Vec<T> = (0..len).map(|_| rng.gen_range(start..=end)).collect();
488
489        // Safe because shape doesn't have to match data from a user
490        Self::new(data, shape).unwrap()
491    }
492
493    /// Creates a matrix where all values are random between 0..=1.
494    /// Shape in new array is given through parameter 'shape'
495    ///
496    /// # Examples
497    ///
498    /// ```
499    /// use sukker::Matrix;
500    ///
501    /// let matrix: Matrix<f64> = Matrix::randomize((2,3));
502    ///
503    /// assert_eq!(matrix.shape(), (2,3));
504    /// ```
505    pub fn randomize(shape: Shape) -> Self {
506        Self::randomize_range(T::zero(), T::one(), shape)
507    }
508
509    /// Parses from file, but will return a default matrix if nothing is given
510    ///
511    /// # Examples
512    ///
513    /// ```
514    /// use sukker::Matrix;
515    ///
516    /// // let m: Matrix<f32> = Matrix::from_file("../../test.txt").unwrap();
517    ///
518    /// // m.print(4);
519    /// ```
520    pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
521        let data =
522            fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path).into())?;
523
524        data.parse::<Self>()
525            .map_err(|_| MatrixError::MatrixParseError.into())
526    }
527
528    /// Constructs a new dense matrix from a sparse one.
529    ///
530    /// This transfesrs ownership as well!
531    ///
532    /// Examples
533    ///
534    /// ```
535    /// use sukker::{Matrix, SparseMatrix};
536    ///
537    /// let sparse = SparseMatrix::<i32>::eye(3);
538    ///
539    /// let matrix = Matrix::from_sparse(sparse);
540    ///
541    /// assert_eq!(matrix.shape(), (3,3));
542    /// assert_eq!(matrix.at(0,0), 1);
543    /// ```
544    pub fn from_sparse(sparse: SparseMatrix<'a, T>) -> Self {
545        let mut mat = Self::zeros(sparse.shape());
546
547        for (&idx, &val) in sparse.data.iter() {
548            mat.set(val, idx);
549        }
550
551        mat
552    }
553
554    /// Helper function to create matrices
555    fn from_shape(value: T, shape: Shape) -> Self {
556        let (rows, cols) = shape;
557
558        let len: usize = rows * cols;
559
560        let data = vec![value; len];
561
562        Self::new(data, shape).unwrap()
563    }
564}
565
/// Selects which axis an operation works along.
///
/// The type is `Copy`, so the same value can be passed to several calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dimension {
    /// The row axis (discriminant 0). For example, concatenating along
    /// `Row` appends rows.
    Row = 0,
    /// The column axis (discriminant 1). For example, concatenating along
    /// `Col` appends columns.
    Col = 1,
}
573
574/// Regular matrix methods that are not operating math on them
575impl<'a, T> Matrix<'a, T>
576where
577    T: MatrixElement + Div<Output = T> + Sum<T>,
578    <T as FromStr>::Err: Error + 'static,
579    Vec<T>: IntoParallelIterator,
580    Vec<&'a T>: IntoParallelRefIterator<'a>,
581{
582    /// Reshapes a matrix if possible.
583    /// If the shapes don't match up, the old shape will be retained
584    ///
585    /// # Examples
586    ///
587    /// ```
588    /// use sukker::Matrix;
589    ///
590    /// let mut matrix = Matrix::init(10.5, (2,3));
591    /// matrix.reshape(3,2);
592    ///
593    /// assert_eq!(matrix.shape(), (3,2));
594    /// ```
595    pub fn reshape(&mut self, nrows: usize, ncols: usize) {
596        if nrows * ncols != self.size() {
597            eprintln!("Err: Can not reshape.. Keeping old dimensions for now");
598            return;
599        }
600
601        self.nrows = nrows;
602        self.ncols = ncols;
603    }
604
605    /// Get the total size of the matrix
606    ///
607    /// # Examples
608    ///
609    /// ```
610    /// use sukker::Matrix;
611    ///
612    /// let matrix = Matrix::init(10.5, (2,3));
613    ///
614    /// assert_eq!(matrix.size(), 6);
615    /// ```
616    pub fn size(&self) -> usize {
617        self.nrows * self.ncols
618    }
619
620    ///  Gets element based on is and js
621    ///
622    /// # Examples
623    ///
624    /// ```
625    /// use sukker::Matrix;
626    ///
627    /// let matrix = Matrix::init(10.5f32, (2,3));
628    ///
629    /// assert_eq!(matrix.get(0,1).unwrap(), 10.5f32);
630    /// ```
631    pub fn get(&self, i: usize, j: usize) -> Option<T> {
632        let idx = at!(i, j, self.ncols);
633
634        if idx >= self.size() {
635            return None;
636        }
637
638        Some(self.at(i, j))
639    }
640
641    ///  Gets element based on is and js, but will
642    ///  panic if indexes are out of range.
643    ///
644    /// # Examples
645    ///
646    /// ```
647    /// use sukker::Matrix;
648    ///
649    /// let val = 10.5;
650    ///
651    /// let matrix = Matrix::init(val, (2,3));
652    ///
653    /// assert_eq!(matrix.at(1,2), val);
654    /// ```
655    #[inline(always)]
656    pub fn at(&self, i: usize, j: usize) -> T {
657        self.data[at!(i, j, self.ncols)]
658    }
659
660    ///  Gets a piece of the matrix out as a vector
661    ///
662    ///  If some indeces are out of bounds, the vec up until that point
663    ///  will be returned
664    ///
665    /// # Examples
666    ///
667    /// ```
668    /// use sukker::Matrix;
669    ///
670    /// let matrix = Matrix::init(10.5, (4,4));
671    /// let slice = matrix.get_vec_slice((1,1), (2,2));
672    ///
673    /// assert_eq!(slice, vec![10.5,10.5,10.5,10.5]);
674    /// ```
675    pub fn get_vec_slice(&self, start_idx: Shape, size: Shape) -> Vec<T> {
676        let (start_row, start_col) = start_idx;
677        let (dx, dy) = size;
678
679        iproduct!(start_row..start_row + dy, start_col..start_col + dx)
680            .filter_map(|(i, j)| self.get(i, j))
681            .collect()
682    }
683
684    /// Gets you the whole entire matrix as a vector
685    ///
686    /// # Examples
687    ///
688    /// ```
689    /// use sukker::Matrix;
690    ///
691    /// let matrix = Matrix::init(10.5, (4,4));
692    /// let slice = matrix.get_vec_slice((1,1), (2,2));
693    ///
694    /// assert_eq!(slice, vec![10.5,10.5,10.5,10.5]);
695    /// ```
696    pub fn get_vec(&self) -> Vec<T> {
697        self.data.clone()
698    }
699
700    ///  Gets a piece of the matrix out as a matrix
701    ///
702    ///  If some indeces are out of bounds, unlike `get_vec_slice`
703    ///  this function will return an IndexOutOfBoundsError
704    ///  and will not return data
705    ///
706    /// # Examples
707    ///
708    /// ```
709    /// use sukker::Matrix;
710    ///
711    /// let matrix = Matrix::init(10.5, (4,4));
712    /// let sub_matrix = matrix.get_sub_matrix((1,1), (2,2)).unwrap();
713    ///
714    /// assert_eq!(sub_matrix.get_vec(), vec![10.5,10.5,10.5,10.5]);
715    /// ```
716    pub fn get_sub_matrix(&self, start_idx: Shape, size: Shape) -> Result<Self, MatrixError> {
717        let (start_row, start_col) = start_idx;
718        let (dx, dy) = size;
719
720        let data = iproduct!(start_row..start_row + dy, start_col..start_col + dx)
721            .filter_map(|(i, j)| self.get(i, j))
722            .collect();
723
724        return match Self::new(data, size) {
725            Ok(a) => Ok(a),
726            Err(_) => Err(MatrixError::MatrixIndexOutOfBoundsError.into()),
727        };
728    }
729
730    /// Concat two mtrices on a dimension
731    ///
732    /// # Examples
733    ///
734    /// ```
735    /// use sukker::Matrix;
736    /// use sukker::Dimension;
737    ///
738    /// let matrix = Matrix::init(10.5, (4,4));
739    /// let matrix2 = Matrix::init(10.5, (1,4));
740    ///
741    /// let res = matrix.concat(&matrix2, Dimension::Row).unwrap();
742    ///
743    /// assert_eq!(res.shape(), (5,4));
744    /// ```
745    pub fn concat(&self, other: &Self, dim: Dimension) -> Result<Self, MatrixError> {
746        match dim {
747            Dimension::Row => {
748                if self.ncols != other.ncols {
749                    return Err(MatrixError::MatrixConcatinationError.into());
750                }
751
752                let mut new_data = self.data.clone();
753
754                new_data.extend(other.data.iter());
755
756                let nrows = self.nrows + other.nrows;
757                let shape = (nrows, self.ncols);
758
759                return Ok(Self::new(new_data, shape).unwrap());
760            }
761
762            Dimension::Col => {
763                if self.nrows != other.nrows {
764                    return Err(MatrixError::MatrixConcatinationError.into());
765                }
766
767                let mut new_data: Vec<T> = Vec::new();
768
769                let take_self = self.ncols;
770                let take_other = other.ncols;
771
772                for (idx, _) in self.data.iter().step_by(take_self).enumerate() {
773                    // Add from self, then other
774                    let row = (idx / take_self) * take_self;
775                    new_data.extend(self.data.iter().skip(row).take(take_self));
776                    new_data.extend(other.data.iter().skip(row).take(take_other));
777                }
778
779                let ncols = self.ncols + other.ncols;
780                let shape = (self.nrows, ncols);
781
782                return Ok(Self::new(new_data, shape).unwrap());
783            }
784        };
785    }
786
787    // TODO: Add option to transpose to be able to extend
788    // Doens't change anything if dimension mismatch
789
790    /// Extend a matrix with another on a dimension
791    ///
792    /// # Examples
793    ///
794    /// ```
795    /// use sukker::Matrix;
796    /// use sukker::Dimension;
797    ///
798    /// let mut matrix = Matrix::init(10.5, (4,4));
799    /// let matrix2 = Matrix::init(10.5, (4,1));
800    ///
801    /// matrix.extend(&matrix2, Dimension::Col);
802    ///
803    /// assert_eq!(matrix.shape(), (4,5));
804    /// ```
805    pub fn extend(&mut self, other: &Self, dim: Dimension) {
806        match dim {
807            Dimension::Row => {
808                if self.ncols != other.ncols {
809                    eprintln!("Error: Dimension mismatch");
810                    return;
811                }
812
813                self.data.extend(other.data.iter());
814
815                self.nrows += other.nrows;
816            }
817
818            Dimension::Col => {
819                if self.nrows != other.nrows {
820                    eprintln!("Error: Dimension mismatch");
821                    return;
822                }
823
824                let mut new_data: Vec<T> = Vec::new();
825
826                let take_self = self.ncols;
827                let take_other = other.ncols;
828
829                for (idx, _) in self.data.iter().step_by(take_self).enumerate() {
830                    // Add from self, then other
831                    let row = (idx / take_self) * take_self;
832                    new_data.extend(self.data.iter().skip(row).take(take_self));
833                    new_data.extend(other.data.iter().skip(row).take(take_other));
834                }
835
836                self.ncols += other.ncols;
837            }
838        };
839    }
840
841    ///  Sets element based on is and js
842    ///
843    ///  Sets nothing if you;re out of bounds
844    ///
845    /// # Examples
846    ///
847    /// ```
848    /// use sukker::Matrix;
849    ///
850    /// let mut matrix = Matrix::init(10.5, (2,3));
851    /// matrix.set(11.5, (1, 2));
852    ///
853    /// assert_eq!(matrix.get(1,2).unwrap(), 11.5);
854    /// ```
855    pub fn set(&mut self, value: T, idx: Shape) {
856        let idx = at!(idx.0, idx.1, self.ncols);
857
858        if idx >= self.size() {
859            eprintln!("Error: Index out of bounds. Not setting value.");
860            return;
861        }
862
863        self.data[idx] = value;
864    }
865
866    ///  Sets many elements based on vector of indeces
867    ///
868    ///  For indexes out of bounds, nothing is set
869    ///
870    /// # Examples
871    ///
872    /// ```
873    /// use sukker::Matrix;
874    ///
875    /// let mut matrix = Matrix::init(10.5, (2,3));
876    /// matrix.set_many(vec![(1,2), (1,1)], 11.5);
877    ///
878    /// assert_eq!(matrix.get(1,2).unwrap(), 11.5);
879    /// assert_eq!(matrix.get(1,1).unwrap(), 11.5);
880    /// assert_eq!(matrix.get(0,1).unwrap(), 10.5);
881    /// ```
882    pub fn set_many(&mut self, idx_list: Vec<Shape>, value: T) {
883        idx_list.iter().for_each(|&idx| self.set(value, idx));
884    }
885
886    /// Sets all elements of a matrix in a 1d range.
887    ///
888    /// The range is inclusive to stop, and will panic
889    /// if any indexes are out of range
890    ///
891    /// # Examples
892    ///
893    /// ```
894    /// use sukker::Matrix;
895    ///
896    /// let mut matrix = Matrix::init(10.5, (2,3));
897    /// matrix.set_range(0, 3, 11.5);
898    ///
899    /// assert_eq!(matrix.get(0,2).unwrap(), 11.5);
900    /// assert_eq!(matrix.get(0,1).unwrap(), 11.5);
901    /// assert_eq!(matrix.get(1,1).unwrap(), 10.5);
902    /// ```
903    pub fn set_range(&mut self, start: usize, stop: usize, value: T) {
904        (start..=stop).for_each(|i| self.data[i] = value);
905    }
906
907    /// Calculates the (row, col) for a matrix by a single index
908    ///
909    /// # Examples
910    ///
911    /// ```
912    /// use sukker::Matrix;
913    ///
914    /// let matrix = Matrix::init(10.5, (2,2));
915    /// let inv = matrix.one_to_2d_idx(1);
916    ///
917    /// assert_eq!(inv, (0,1));
918    /// ```
919    pub fn one_to_2d_idx(&self, idx: usize) -> Shape {
920        let row = idx / self.ncols;
921        let col = idx % self.ncols;
922
923        (row, col)
924    }
925
926    /// Finds maximum element in the matrix
927    ///
928    /// # Examples
929    ///
930    /// ```
931    /// use sukker::Matrix;
932    ///
933    /// let matrix = Matrix::init(10.5, (2,3));
934    ///
935    /// assert_eq!(matrix.max(), 10.5);
936    /// ```
937    pub fn max(&self) -> T {
938        // Matrix must have at least one element, thus we can unwrap
939        *self
940            .data
941            .par_iter()
942            .max_by(|a, b| a.partial_cmp(b).unwrap())
943            .unwrap()
944    }
945
946    /// Finds minimum element in the matrix
947    ///
948    /// # Examples
949    ///
950    /// ```
951    /// use sukker::Matrix;
952    ///
953    /// let mut matrix = Matrix::init(10.5, (2,3));
954    /// matrix.set(1.0, (0,2));
955    ///
956    /// assert_eq!(matrix.min(), 1.0);
957    /// ```
958    pub fn min(&self) -> T {
959        // Matrix must have at least one element, thus we can unwrap
960        *self
961            .data
962            .par_iter()
963            .min_by(|a, b| a.partial_cmp(b).unwrap())
964            .unwrap()
965    }
966
967    /// Finds position in matrix where value is highest.
968    /// Restricted to find this across a row or column
969    /// in the matrix.
970    ///
971    /// # Examples
972    ///
973    /// ```
974    /// use sukker::{Matrix, Dimension};
975    ///
976    /// let mut matrix = Matrix::init(1.0, (3,3));
977    /// matrix.set(15.0, (0,2));
978    ///
979    /// ```
980    fn argmax(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
981        match dimension {
982            Dimension::Row => {
983                if rowcol >= self.nrows - 1 {
984                    return None;
985                }
986
987                let mut highest: T = T::one();
988                let mut i = 0;
989
990                for (idx, elem) in self
991                    .data
992                    .iter()
993                    .enumerate()
994                    .skip(rowcol * self.ncols)
995                    .take(self.ncols)
996                {
997                    if *elem >= highest {
998                        i = idx;
999                    }
1000                }
1001
1002                Some(self.one_to_2d_idx(i))
1003            }
1004
1005            Dimension::Col => {
1006                if rowcol >= self.ncols - 1 {
1007                    return None;
1008                }
1009
1010                let mut highest: T = T::one();
1011
1012                let mut i = 0;
1013
1014                for (idx, elem) in self
1015                    .data
1016                    .iter()
1017                    .enumerate()
1018                    .skip(rowcol)
1019                    .step_by(self.ncols)
1020                {
1021                    if *elem >= highest {
1022                        i = idx;
1023                    }
1024                }
1025
1026                Some(self.one_to_2d_idx(i))
1027            }
1028        }
1029    }
1030
1031    /// Finds position in matrix where value is lowest.
1032    /// Restricted to find this across a row or column
1033    /// in the matrix.
1034    ///
1035    ///
1036    /// # Examples
1037    ///
1038    /// ```
1039    /// use sukker::{Matrix, Dimension};
1040    ///
1041    /// let mut matrix = Matrix::init(10.5, (3,3));
1042    /// matrix.set(1.0, (0,1));
1043    ///
1044    /// // assert_eq!(matrix.argmin(1, Dimension::Col), Some(1));
1045    /// ```
1046    fn argmin(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
1047        match dimension {
1048            Dimension::Row => {
1049                if rowcol >= self.nrows - 1 {
1050                    return None;
1051                }
1052
1053                let mut lowest: T = T::zero();
1054
1055                let mut i = 0;
1056
1057                for (idx, elem) in self
1058                    .data
1059                    .iter()
1060                    .enumerate()
1061                    .skip(rowcol * self.ncols)
1062                    .take(self.ncols)
1063                {
1064                    if *elem < lowest {
1065                        i = idx;
1066                    }
1067                }
1068
1069                Some(self.one_to_2d_idx(i))
1070            }
1071
1072            Dimension::Col => {
1073                if rowcol >= self.ncols - 1 {
1074                    return None;
1075                }
1076
1077                let mut lowest: T = T::zero();
1078
1079                let mut i = 0;
1080
1081                for (idx, elem) in self
1082                    .data
1083                    .iter()
1084                    .enumerate()
1085                    .skip(rowcol)
1086                    .step_by(self.ncols)
1087                {
1088                    if *elem <= lowest {
1089                        i = idx;
1090                    }
1091                }
1092
1093                Some(self.one_to_2d_idx(i))
1094            }
1095        }
1096    }
1097
1098    /// Finds total sum of matrix
1099    ///
1100    /// # Examples
1101    ///
1102    /// ```
1103    /// use sukker::Matrix;
1104    ///
1105    /// let matrix = Matrix::init(10f32, (2,2));
1106    ///
1107    /// assert_eq!(matrix.cumsum(), 40.0);
1108    /// ```
1109    pub fn cumsum(&self) -> T {
1110        if self.size() == 0 {
1111            return T::zero();
1112        }
1113
1114        self.data.par_iter().copied().sum()
1115    }
1116
1117    /// Multiplies  all elements in matrix
1118    ///
1119    /// # Examples
1120    ///
1121    /// ```
1122    /// use sukker::Matrix;
1123    ///
1124    /// let matrix = Matrix::init(10f32, (2,2));
1125    ///
1126    /// assert_eq!(matrix.cumprod(), 10000.0);
1127    /// ```
1128    pub fn cumprod(&self) -> T {
1129        if self.size() == 0 {
1130            return T::zero();
1131        }
1132
1133        self.data.par_iter().copied().product()
1134    }
1135
1136    /// Gets the average of the matrix
1137    ///
1138    /// # Examples
1139    ///
1140    /// ```
1141    /// use sukker::Matrix;
1142    ///
1143    /// let matrix = Matrix::init(10f32, (2,2));
1144    ///
1145    /// assert_eq!(matrix.avg(), 10.0);
1146    /// ```
1147    pub fn avg(&self) -> T {
1148        self.data.par_iter().copied().sum::<T>() / self.size().to_string().parse::<T>().unwrap()
1149    }
1150
1151    /// Gets the mean of the matrix
1152    ///
1153    /// # Examples
1154    ///
1155    /// ```
1156    /// use sukker::Matrix;
1157    ///
1158    /// let matrix = Matrix::init(10f32, (2,2));
1159    ///
1160    /// assert_eq!(matrix.mean(), 10.0);
1161    /// ```
1162    pub fn mean(&self) -> T {
1163        self.avg()
1164    }
1165
1166    /// Gets the median of the matrix
1167    ///
1168    /// # Examples
1169    ///
1170    /// ```
1171    /// use sukker::Matrix;
1172    ///
1173    /// let matrix = Matrix::new(vec![1.0, 4.0, 6.0, 5.0], (2,2)).unwrap();
1174    ///
1175    /// assert!(matrix.median() >= 4.45 && matrix.median() <= 4.55);
1176    /// ```
1177    pub fn median(&self) -> T {
1178        if self.size() == 1 {
1179            return self.at(0, 0);
1180        }
1181
1182        match self.data.len() % 2 {
1183            0 => {
1184                let half: usize = self.data.len() / 2;
1185
1186                self.data
1187                    .iter()
1188                    .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
1189                    .skip(half - 1)
1190                    .take(2)
1191                    .copied()
1192                    .sum::<T>()
1193                    / (T::one() + T::one())
1194            }
1195            1 => {
1196                let half: usize = self.data.len() / 2;
1197
1198                self.data
1199                    .iter()
1200                    .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
1201                    .nth(half)
1202                    .copied()
1203                    .unwrap()
1204            }
1205            _ => unreachable!(),
1206        }
1207    }
1208
1209    /// Sums up elements over given axis and dimension.
1210    /// Will return 0 if you're out of bounds
1211    ///
1212    /// sum(2, Dimension::Col) means summig up these ones
1213    ///
1214    /// [ 10 10 (10) 10 10
1215    ///   10 10 (10) 10 10
1216    ///   10 10 (10) 10 10
1217    ///   10 10 (10) 10 10
1218    ///   10 10 (10) 10 10 ]
1219    ///
1220    ///   = 10 * 5 = 50
1221    ///
1222    /// # Examples
1223    ///
1224    /// ```
1225    /// use sukker::Matrix;
1226    /// use sukker::Dimension;
1227    ///
1228    /// let matrix = Matrix::init(10f32, (5,5));
1229    ///
1230    /// assert_eq!(matrix.sum(0, Dimension::Row), 50.0);
1231    /// assert_eq!(matrix.sum(3, Dimension::Col), 50.0);
1232    /// ```
1233    pub fn sum(&self, rowcol: usize, dimension: Dimension) -> T {
1234        // TODO: Add out of bounds options
1235        if self.size() == 1 {
1236            return self.at(0, 0);
1237        }
1238
1239        match dimension {
1240            Dimension::Row => self
1241                .data
1242                .par_iter()
1243                .skip(rowcol * self.ncols)
1244                .take(self.ncols)
1245                .copied()
1246                .sum(),
1247            Dimension::Col => self
1248                .data
1249                .par_iter()
1250                .skip(rowcol)
1251                .step_by(self.ncols)
1252                .copied()
1253                .sum(),
1254        }
1255    }
1256
1257    /// Prods up elements over given rowcol and dimension
1258    /// Will return 1 if you're out of bounds.
1259    ///
1260    /// See `sum` for example on how this is calculated
1261    ///
1262    /// # Examples
1263    ///
1264    /// ```
1265    /// use sukker::Matrix;
1266    /// use sukker::Dimension;
1267    ///
1268    /// let matrix = Matrix::init(10f32, (2,2));
1269    ///
1270    /// assert_eq!(matrix.prod(0, Dimension::Row), 100.0);
1271    /// assert_eq!(matrix.prod(0, Dimension::Col), 100.0);
1272    /// ```
1273    pub fn prod(&self, rowcol: usize, dimension: Dimension) -> T {
1274        match dimension {
1275            Dimension::Row => self
1276                .data
1277                .par_iter()
1278                .skip(rowcol * self.ncols)
1279                .take(self.ncols)
1280                .copied()
1281                .product(),
1282            Dimension::Col => self
1283                .data
1284                .par_iter()
1285                .skip(rowcol)
1286                .step_by(self.ncols)
1287                .copied()
1288                .product(),
1289        }
1290    }
1291}
1292
1293/// Linalg on floats
1294impl<'a, T> LinAlgFloats<'a, T> for Matrix<'a, T>
1295where
1296    T: MatrixElement + Float + 'a,
1297    <T as FromStr>::Err: Error + 'static,
1298    Vec<T>: IntoParallelIterator,
1299    Vec<&'a T>: IntoParallelRefIterator<'a>,
1300{
1301    /// Takes the logarithm of each element
1302    ///
1303    /// # Examples
1304    ///
1305    /// ```
1306    /// use sukker::{Matrix, LinAlgFloats};
1307    ///
1308    /// let matrix = Matrix::init(10.0, (2,2));
1309    /// let result = matrix.log(10.0);
1310    ///
1311    /// assert_eq!(result.all(|&e| e == 1.0), true);
1312    ///
1313    /// ```
1314    fn log(&self, base: T) -> Self {
1315        let data: Vec<T> = self.data.par_iter().map(|&e| e.log(base)).collect();
1316
1317        Self::new(data, self.shape()).unwrap()
1318    }
1319
1320    /// Takes the natural logarithm of each element in a matrix
1321    ///
1322    /// # Examples
1323    ///
1324    /// ```
1325    /// use sukker::{Matrix, LinAlgFloats};
1326    /// use sukker::constants::EF64;
1327    ///
1328    /// let matrix: Matrix<f64> = Matrix::init(EF64, (2,2));
1329    ///
1330    /// let res = matrix.ln();
1331    /// ```
1332    fn ln(&self) -> Self {
1333        let data: Vec<T> = self.data.par_iter().map(|&e| e.ln()).collect();
1334
1335        Self::new(data, self.shape()).unwrap()
1336    }
1337
1338    /// Takes the square root of each element in a matrix.
1339    /// If some elements are negative, these will be kept the same
1340    ///
1341    /// # Examples
1342    ///
1343    /// ```
1344    /// use sukker::{Matrix, LinAlgFloats};
1345    ///
1346    /// let matrix: Matrix<f64> = Matrix::init(9.0, (3,3));
1347    ///
1348    /// let res = matrix.sqrt();
1349    ///
1350    /// assert_eq!(res.all(|&e| e == 3.0), true);
1351    /// ```
1352    fn sqrt(&self) -> Self {
1353        let data: Vec<T> = self
1354            .data
1355            .par_iter()
1356            .map(|&e| if e > T::zero() { e.sqrt() } else { e })
1357            .collect();
1358
1359        Self::new(data, self.shape()).unwrap()
1360    }
1361
1362    /// Gets sin of every value
1363    ///
1364    /// # Examples
1365    ///
1366    /// ```
1367    /// use sukker::{Matrix, LinAlgFloats};
1368    ///
1369    /// let matrix = Matrix::init(1.0, (2,2));
1370    ///
1371    /// let res = matrix.sin();
1372    /// ```
1373    fn sin(&self) -> Self {
1374        let data: Vec<T> = self.data.par_iter().map(|&e| e.sin()).collect();
1375
1376        Self::new(data, self.shape()).unwrap()
1377    }
1378
1379    /// Gets cos of every value
1380    ///
1381    /// # Examples
1382    ///
1383    /// ```
1384    /// use sukker::{Matrix, LinAlgFloats};
1385    /// use sukker::constants::EF32;
1386    ///
1387    /// let matrix = Matrix::init(EF32, (2,2));
1388    ///
1389    /// let res = matrix.cos();
1390    /// ```
1391    fn cos(&self) -> Self {
1392        let data: Vec<T> = self.data.par_iter().map(|&e| e.cos()).collect();
1393
1394        Self::new(data, self.shape()).unwrap()
1395    }
1396
1397    /// Gets tan of every value
1398    ///
1399    /// # Examples
1400    ///
1401    /// ```
1402    /// use sukker::{Matrix, LinAlgFloats};
1403    /// use sukker::constants::EF32;
1404    ///
1405    /// let matrix = Matrix::init(EF32, (2,2));
1406    ///
1407    /// let res = matrix.tan();
1408    /// ```
1409    fn tan(&self) -> Self {
1410        let data: Vec<T> = self.data.par_iter().map(|&e| e.tan()).collect();
1411
1412        Self::new(data, self.shape()).unwrap()
1413    }
1414
1415    /// Gets sinh of every value
1416    ///
1417    /// # Examples
1418    ///
1419    /// ```
1420    /// use sukker::{Matrix, LinAlgFloats};
1421    /// use sukker::constants::EF32;
1422    ///
1423    /// let matrix = Matrix::init(EF32, (2,2));
1424    ///
1425    /// let res = matrix.sinh();
1426    /// ```
1427    fn sinh(&self) -> Self {
1428        let data: Vec<T> = self.data.par_iter().map(|&e| e.sinh()).collect();
1429
1430        Self::new(data, self.shape()).unwrap()
1431    }
1432
1433    /// Gets cosh of every value
1434    ///
1435    /// # Examples
1436    ///
1437    /// ```
1438    /// use sukker::{Matrix, LinAlgFloats};
1439    /// use sukker::constants::EF32;
1440    ///
1441    /// let matrix = Matrix::init(EF32, (2,2));
1442    ///
1443    /// let res = matrix.cosh();
1444    /// ```
1445    fn cosh(&self) -> Self {
1446        let data: Vec<T> = self.data.par_iter().map(|&e| e.cosh()).collect();
1447
1448        Self::new(data, self.shape()).unwrap()
1449    }
1450
1451    /// Gets tanh of every value
1452    ///
1453    /// # Examples
1454    ///
1455    /// ```
1456    /// use sukker::{Matrix, LinAlgFloats};
1457    /// use sukker::constants::EF32;
1458    ///
1459    /// let matrix = Matrix::init(EF32, (2,2));
1460    ///
1461    /// let res = matrix.tanh();
1462    /// ```
1463    fn tanh(&self) -> Self {
1464        let data: Vec<T> = self.data.par_iter().map(|&e| e.tanh()).collect();
1465
1466        Self::new(data, self.shape()).unwrap()
1467    }
1468
1469    /// Find the eigenvale of a matrix
1470    ///
1471    /// # Examples
1472    ///
1473    /// ```
1474    /// use sukker::Matrix;
1475    ///
1476    /// let mut matrix = Matrix::init(2.0, (2,100));
1477    ///
1478    /// ```
1479    fn get_eigenvalues(&self) -> Option<Vec<T>> {
1480        todo!()
1481    }
1482
1483    /// Find the eigenvectors
1484    fn get_eigenvectors(&self) -> Option<Vec<T>> {
1485        unimplemented!()
1486    }
1487}
1488
1489/// trait MatrixLinAlg contains all common Linear Algebra functions to be
1490/// performed on matrices
1491impl<'a, T> Matrix<'a, T>
1492where
1493    T: MatrixElement,
1494    <T as FromStr>::Err: Error + 'static,
1495    Vec<T>: IntoParallelIterator,
1496    Vec<&'a T>: IntoParallelRefIterator<'a>,
1497{
1498    /// Adds one matrix to another
1499    ///
1500    /// # Examples
1501    ///
1502    /// ```
1503    /// use sukker::Matrix;
1504    ///
1505    /// let matrix1 = Matrix::init(10.0, (2,2));
1506    /// let matrix2 = Matrix::init(10.0, (2,2));
1507    ///
1508    /// assert_eq!(matrix1.add(&matrix2).unwrap().get(0,0).unwrap(), 20.0);
1509    /// ```
1510    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
1511        if self.shape() != other.shape() {
1512            return Err(MatrixError::MatrixDimensionMismatchError.into());
1513        }
1514
1515        let data: Vec<T> = self
1516            .data
1517            .iter()
1518            .zip(other.data.iter())
1519            .map(|(&x, &y)| x + y)
1520            .collect();
1521
1522        Ok(Self::new(data, self.shape()).unwrap())
1523    }
1524
1525    /// Subtracts one matrix from another
1526    ///
1527    /// # Examples
1528    ///
1529    /// ```
1530    /// use sukker::Matrix;
1531    ///
1532    /// let matrix1 = Matrix::init(10.0, (2,2));
1533    /// let matrix2 = Matrix::init(10.0, (2,2));
1534    ///
1535    /// assert_eq!(matrix1.sub(&matrix2).unwrap().get(1,0).unwrap(), 0.0);
1536    /// ```
1537    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
1538        if self.shape() != other.shape() {
1539            return Err(MatrixError::MatrixDimensionMismatchError.into());
1540        }
1541
1542        let data: Vec<T> = self
1543            .data
1544            .iter()
1545            .zip(other.data.iter())
1546            .map(|(&x, &y)| x - y)
1547            .collect();
1548
1549        Ok(Self::new(data, self.shape()).unwrap())
1550    }
1551
1552    /// Subtracts one array from another and returns the absolute value
1553    ///
1554    /// # Examples
1555    ///
1556    /// ```
1557    /// use sukker::Matrix;
1558    ///
1559    /// let matrix1 = Matrix::init(10.0f32, (2,2));
1560    /// let matrix2 = Matrix::init(15.0f32, (2,2));
1561    ///
1562    /// assert_eq!(matrix1.sub_abs(&matrix2).unwrap().get(0,0).unwrap(), 5.0);
1563    /// ```
1564    pub fn sub_abs(&self, other: &Self) -> Result<Self, MatrixError> {
1565        if self.shape() != other.shape() {
1566            return Err(MatrixError::MatrixDimensionMismatchError.into());
1567        }
1568
1569        let data = self
1570            .data
1571            .iter()
1572            .zip(other.data.iter())
1573            .map(|(&x, &y)| if x > y { x - y } else { y - x })
1574            .collect_vec();
1575
1576        Ok(Self::new(data, self.shape()).unwrap())
1577    }
1578
1579    /// Dot product of two matrices
1580    ///
1581    /// # Examples
1582    ///
1583    /// ```
1584    /// use sukker::Matrix;
1585    ///
1586    /// let matrix1 = Matrix::init(20.0, (2,2));
1587    /// let matrix2 = Matrix::init(10.0, (2,2));
1588    ///
1589    /// assert_eq!(matrix1.mul(&matrix2).unwrap().get(0,0).unwrap(), 200.0);
1590    /// ```
1591    pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
1592        if self.shape() != other.shape() {
1593            return Err(MatrixError::MatrixDimensionMismatchError.into());
1594        }
1595
1596        let data = self
1597            .data
1598            .iter()
1599            .zip(other.data.iter())
1600            .map(|(&x, &y)| x * y)
1601            .collect_vec();
1602
1603        Ok(Self::new(data, self.shape()).unwrap())
1604    }
1605
1606    /// Dot product of two matrices
1607    ///
1608    /// # Examples
1609    ///
1610    /// ```
1611    /// use sukker::Matrix;
1612    ///
1613    /// let matrix1 = Matrix::init(20.0, (2,2));
1614    /// let matrix2 = Matrix::init(10.0, (2,2));
1615    ///
1616    /// assert_eq!(matrix1.dot(&matrix2).unwrap().get(0,0).unwrap(), 200.0);
1617    /// ```
1618    pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
1619        self.mul(other)
1620    }
1621
1622    /// Bad handling of zero div
1623    ///
1624    /// # Examples
1625    ///
1626    /// ```
1627    /// use sukker::Matrix;
1628    ///
1629    /// let matrix1 = Matrix::init(20.0, (2,2));
1630    /// let matrix2 = Matrix::init(10.0, (2,2));
1631    ///
1632    /// assert_eq!(matrix1.div(&matrix2).unwrap().get(0,0).unwrap(), 2.0);
1633    /// ```
1634    pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
1635        if self.shape() != other.shape() {
1636            return Err(MatrixError::MatrixDimensionMismatchError.into());
1637        }
1638
1639        if other.any(|e| e == &T::zero()) {
1640            return Err(MatrixError::MatrixDivideByZeroError.into());
1641        }
1642
1643        let data = self
1644            .data
1645            .iter()
1646            .zip(other.data.iter())
1647            .map(|(&x, &y)| x / y)
1648            .collect_vec();
1649
1650        Ok(Self::new(data, self.shape()).unwrap())
1651    }
1652
1653    /// Negates every value in the matrix
1654    ///
1655    /// # Examples
1656    ///
1657    /// ```
1658    /// use sukker::{Matrix, LinAlgFloats};
1659    ///
1660    /// let matrix = Matrix::<f32>::ones((2,2));
1661    ///
1662    /// let negated = matrix.neg();
1663    ///
1664    /// assert_eq!(negated.all(|&e| e == -1.0), true);
1665    /// ```
1666    pub fn neg(&self) -> Self {
1667        let data: Vec<T> = self.data.par_iter().map(|&e| e.neg()).collect();
1668
1669        Self::new(data, self.shape()).unwrap()
1670    }
1671
1672    /// Adds a value to a matrix and returns a new matrix
1673    ///
1674    /// # Examples
1675    ///
1676    /// ```
1677    /// use sukker::Matrix;
1678    ///
1679    /// let matrix = Matrix::init(20.0, (2,2));
1680    /// let value: f32 = 2.0;
1681    /// assert_eq!(matrix.add_val(value).get(0,0).unwrap(), 22.0);
1682    /// ```
1683    pub fn add_val(&self, val: T) -> Self {
1684        let data: Vec<T> = self.data.par_iter().map(|&e| e + val).collect();
1685
1686        Self::new(data, self.shape()).unwrap()
1687    }
1688
1689    /// Substracts a value to a matrix and returns a new matrix
1690    ///
1691    /// # Examples
1692    ///
1693    /// ```
1694    /// use sukker::Matrix;
1695    ///
1696    /// let matrix = Matrix::init(20.0, (2,2));
1697    /// let value: f32 = 2.0;
1698    /// assert_eq!(matrix.sub_val(value).get(0,0).unwrap(), 18.0);
1699    /// ```
1700    pub fn sub_val(&self, val: T) -> Self {
1701        let data: Vec<T> = self.data.par_iter().map(|&e| e - val).collect();
1702
1703        Self::new(data, self.shape()).unwrap()
1704    }
1705
1706    /// Multiplies a value to a matrix and returns a new matrix
1707    ///
1708    /// # Examples
1709    ///
1710    /// ```
1711    /// use sukker::Matrix;
1712    ///
1713    /// let matrix = Matrix::init(20.0, (2,2));
1714    /// let value: f32 = 2.0;
1715    /// assert_eq!(matrix.mul_val(value).get(0,0).unwrap(), 40.0);
1716    /// ```
1717    pub fn mul_val(&self, val: T) -> Self {
1718        let data: Vec<T> = self.data.par_iter().map(|&e| e * val).collect();
1719
1720        Self::new(data, self.shape()).unwrap()
1721    }
1722
1723    /// Divides a value to a matrix and returns a new matrix
1724    ///
1725    /// # Examples
1726    ///
1727    /// ```
1728    /// use sukker::Matrix;
1729    ///
1730    /// let matrix = Matrix::init(20.0, (2,2));
1731    /// let value: f32 = 2.0;
1732    ///
1733    /// let result_mat = matrix.div_val(value);
1734    ///
1735    /// assert_eq!(result_mat.get(0,0).unwrap(), 10.0);
1736    /// ```
1737    pub fn div_val(&self, val: T) -> Self {
1738        let data: Vec<T> = self.data.par_iter().map(|&e| e / val).collect();
1739
1740        Self::new(data, self.shape()).unwrap()
1741    }
1742
1743    /// Pows each value in a matrix by val times
1744    ///
1745    /// # Examples
1746    ///
1747    /// ```
1748    /// use sukker::Matrix;
1749    ///
1750    /// let matrix = Matrix::init(2.0, (2,2));
1751    ///
1752    /// let result_mat = matrix.pow(2);
1753    ///
1754    /// assert_eq!(result_mat.get_vec(), vec![4.0, 4.0, 4.0, 4.0]);
1755    /// ```
1756    pub fn pow(&self, val: usize) -> Self {
1757        let data: Vec<T> = self.data.par_iter().map(|&e| pow(e, val)).collect();
1758
1759        Self::new(data, self.shape()).unwrap()
1760    }
1761
1762    /// Takes the absolute values of the matrix
1763    ///
1764    /// # Examples
1765    ///
1766    /// ```
1767    /// use sukker::Matrix;
1768    ///
1769    /// let matrix = Matrix::init(-20.0, (2,2));
1770    ///
1771    /// let res = matrix.abs();
1772    ///
1773    /// assert_eq!(res.all(|&e| e == 20.0), true);
1774    /// ```
1775    pub fn abs(&self) -> Self {
1776        let data: Vec<T> = self.data.par_iter().map(|&e| e.abs()).collect();
1777
1778        Self::new(data, self.shape()).unwrap()
1779    }
1780
1781    /// Multiply a matrix with itself n number of times.
1782    /// This is done by performing a matrix multiplication
1783    /// several time on self and the result of mat.exp(i-1).
1784    ///
1785    /// If matrix is not in form NxN, this function returns None
1786    ///
1787    /// Examples
1788    ///
1789    /// ```
1790    /// use sukker::Matrix;
1791    ///
1792    /// let a = Matrix::<i32>::init(2, (2,2));
1793    ///
1794    /// let res = a.exp(3).unwrap();
1795    ///
1796    /// assert_eq!(res.all(|&e| e == 32), true);
1797    /// ```
1798    pub fn exp(&self, n: usize) -> Option<Self> {
1799        if self.nrows != self.ncols {
1800            return None;
1801        }
1802
1803        let mut res = self.clone();
1804
1805        (0..n - 1).for_each(|_| res = res.matmul(self).unwrap());
1806
1807        Some(res)
1808    }
1809
1810    /// Adds a matrix in-place to a matrix
1811    ///
1812    /// # Examples
1813    ///
1814    /// ```
1815    /// use sukker::Matrix;
1816    ///
1817    /// let mut matrix1 = Matrix::init(20.0, (2,2));
1818    /// let matrix2 = Matrix::init(2.0, (2,2));
1819    ///
1820    /// matrix1.add_self(&matrix2);
1821    ///
1822    /// assert_eq!(matrix1.get(0,0).unwrap(), 22.0);
1823    /// ```
1824    pub fn add_self(&mut self, other: &Self) {
1825        self.data
1826            .par_iter_mut()
1827            .zip(&other.data)
1828            .for_each(|(a, b)| *a += *b);
1829    }
1830
1831    /// Subtracts a matrix in-place to a matrix
1832    ///
1833    /// # Examples
1834    ///
1835    /// ```
1836    /// use sukker::Matrix;
1837    ///
1838    /// let mut matrix1 = Matrix::init(20.0, (2,2));
1839    /// let matrix2 = Matrix::init(2.0, (2,2));
1840    ///
1841    /// matrix1.sub_self(&matrix2);
1842    ///
1843    /// assert_eq!(matrix1.get(0,0).unwrap(), 18.0);
1844    /// ```
1845    pub fn sub_self(&mut self, other: &Self) {
1846        self.data
1847            .par_iter_mut()
1848            .zip(&other.data)
1849            .for_each(|(a, b)| *a -= *b);
1850    }
1851
1852    /// Multiplies a matrix in-place to a matrix
1853    ///
1854    /// # Examples
1855    ///
1856    /// ```
1857    /// use sukker::Matrix;
1858    ///
1859    /// let mut matrix1 = Matrix::init(20.0, (2,2));
1860    /// let matrix2 = Matrix::init(2.0, (2,2));
1861    ///
1862    /// matrix1.mul_self(&matrix2);
1863    ///
1864    /// assert_eq!(matrix1.get(0,0).unwrap(), 40.0);
1865    /// ```
1866    pub fn mul_self(&mut self, other: &Self) {
1867        self.data
1868            .par_iter_mut()
1869            .zip(&other.data)
1870            .for_each(|(a, b)| *a *= *b);
1871    }
1872
1873    /// Divides a matrix in-place to a matrix
1874    ///
1875    /// # Examples
1876    ///
1877    /// ```
1878    /// use sukker::Matrix;
1879    ///
1880    /// let mut matrix1 = Matrix::init(20.0, (2,2));
1881    /// let matrix2 = Matrix::init(2.0, (2,2));
1882    ///
1883    /// matrix1.div_self(&matrix2);
1884    ///
1885    /// assert_eq!(matrix1.get(0,0).unwrap(), 10.0);
1886    /// ```
1887    pub fn div_self(&mut self, other: &Self) {
1888        self.data
1889            .par_iter_mut()
1890            .zip(&other.data)
1891            .for_each(|(a, b)| *a /= *b);
1892    }
1893
1894    /// Abs matrix in-place to a matrix
1895    ///
1896    /// # Examples
1897    ///
1898    /// ```
1899    /// use sukker::Matrix;
1900    ///
1901    /// let mut matrix = Matrix::init(20.0, (2,2));
1902    ///
1903    /// matrix.abs_self()
1904    ///
1905    /// // assert_eq!(matrix1.get(0,0).unwrap(), 22.0);
1906    /// ```
1907    pub fn abs_self(&mut self) {
1908        self.data.par_iter_mut().for_each(|e| *e = abs(*e))
1909    }
1910
1911    /// Adds a value in-place to a matrix
1912    ///
1913    /// # Examples
1914    ///
1915    /// ```
1916    /// use sukker::Matrix;
1917    ///
1918    /// let mut matrix = Matrix::init(20.0, (2,2));
1919    /// let value: f32 = 2.0;
1920    ///
1921    /// matrix.add_val_self(value);
1922    ///
1923    /// assert_eq!(matrix.get(0,0).unwrap(), 22.0);
1924    /// ```
1925    pub fn add_val_self(&mut self, val: T) {
1926        self.data.par_iter_mut().for_each(|e| *e += val);
1927    }
1928
1929    /// Subtracts a value in-place to a matrix
1930    ///
1931    /// # Examples
1932    ///
1933    /// ```
1934    /// use sukker::Matrix;
1935    ///
1936    /// let mut matrix = Matrix::init(20.0, (2,2));
1937    /// let value: f32 = 2.0;
1938    ///
1939    /// matrix.sub_val_self(value);
1940    ///
1941    /// assert_eq!(matrix.get(0,0).unwrap(), 18.0);
1942    /// ```
1943    pub fn sub_val_self(&mut self, val: T) {
1944        self.data.par_iter_mut().for_each(|e| *e -= val);
1945    }
1946
1947    /// Mults a value in-place to a matrix
1948    ///
1949    /// # Examples
1950    ///
1951    /// ```
1952    /// use sukker::Matrix;
1953    ///
1954    /// let mut matrix = Matrix::init(20.0, (2,2));
1955    /// let value: f32 = 2.0;
1956    ///
1957    /// matrix.mul_val_self(value);
1958    ///
1959    /// assert_eq!(matrix.get(0,0).unwrap(), 40.0);
1960    /// ```
1961    pub fn mul_val_self(&mut self, val: T) {
1962        self.data.par_iter_mut().for_each(|e| *e *= val);
1963    }
1964
1965    /// Divs a value in-place to a matrix
1966    ///
1967    /// # Examples
1968    ///
1969    /// ```
1970    /// use sukker::Matrix;
1971    ///
1972    /// let mut matrix = Matrix::init(20.0, (2,2));
1973    /// let value: f32 = 2.0;
1974    ///
1975    /// matrix.div_val_self(value);
1976    ///
1977    /// assert_eq!(matrix.get(0,0).unwrap(), 10.0);
1978    /// ```
1979    pub fn div_val_self(&mut self, val: T) {
1980        self.data.par_iter_mut().for_each(|e| *e /= val);
1981    }
1982
1983    /// Transposed matrix multiplications
1984    ///
1985    /// # Examples
1986    ///
1987    /// ```
1988    /// use sukker::Matrix;
1989    ///
1990    /// let mut matrix1 = Matrix::init(2.0, (2,4));
1991    /// let matrix2 = Matrix::init(2.0, (4,2));
1992    ///
1993    /// let result = matrix1.matmul(&matrix2).unwrap();
1994    ///
1995    /// assert_eq!(result.get(0,0).unwrap(), 16.0);
1996    /// assert_eq!(result.shape(), (2,2));
1997    /// ```
1998    pub fn matmul(&self, other: &Self) -> Result<Self, MatrixError> {
1999        // assert M N x N P
2000        if self.ncols != other.nrows {
2001            return Err(MatrixError::MatrixDimensionMismatchError.into());
2002        }
2003
2004        Ok(self.matmul_helper(other))
2005    }
2006
2007    /// Shorthand method for matmul
2008    pub fn mm(&self, other: &Self) -> Result<Self, MatrixError> {
2009        self.matmul(other)
2010    }
2011
2012    /// Get's the determinat of a N x N matrix
2013    ///
2014    /// Examples
2015    ///
2016    /// ```
2017    /// use sukker::Matrix;
2018    ///
2019    /// let mat: Matrix<i32> = Matrix::new(vec![1,3,5,9,1,3,1,7,4,3,9,7,5,2,0,9], (4,4)).unwrap();
2020    ///
2021    ///
2022    /// let res = mat.determinant().unwrap();
2023    ///
2024    /// assert_eq!(res, -376);
2025    /// ```
2026    pub fn determinant(&self) -> Option<T> {
2027        if self.nrows != self.ncols {
2028            return None;
2029        }
2030
2031        Some(self.determinant_helper())
2032    }
2033
2034    /// Shorthand call for `determinant`
2035    pub fn det(&self) -> Option<T> {
2036        self.determinant()
2037    }
2038
2039    /// Finds the inverse of a matrix if possible
2040    ///
2041    /// Definition: AA^-1 = A^-1A = I
2042    ///
2043    /// Examples
2044    ///
2045    /// ```
2046    /// use sukker::Matrix;
2047    ///
2048    /// let matrix = Matrix::new(vec![4,7,2,6], (2,2)).unwrap();
2049    ///
2050    /// // let inverse  = matrix.inverse();
2051    ///
2052    /// ```
2053    pub fn inverse(&self) -> Option<Self> {
2054        if self.shape() != (2, 2) {
2055            eprintln!("Function not implemented for inverse on larger matrices yet!");
2056            return None;
2057        }
2058        if self.nrows != self.ncols {
2059            eprintln!("Oops");
2060            return None;
2061        }
2062
2063        if self.determinant().unwrap() == T::zero() {
2064            return None;
2065        }
2066
2067        let a = self.at(0, 0);
2068        let b = self.at(0, 1);
2069        let c = self.at(1, 0);
2070        let d = self.at(1, 1);
2071
2072        let mut mat = Self::new(vec![d, -b, -c, a], self.shape()).unwrap();
2073
2074        mat.mul_val_self(T::one() / (a * d - b * c));
2075
2076        return Some(mat);
2077
2078        // let mut inverse = Self::zeros_like(self);
2079        //
2080        // let identity_mat = Self::eye_like(self);
2081        //
2082        // Some(inverse)
2083    }
2084
2085    /// Transpose a matrix in-place
2086    ///
2087    /// # Examples
2088    ///
2089    /// ```
2090    /// use sukker::Matrix;
2091    ///
2092    /// let mut matrix = Matrix::init(2.0, (2,100));
2093    /// matrix.transpose();
2094    ///
2095    /// assert_eq!(matrix.shape(), (100,2));
2096    /// ```
2097    pub fn transpose(&mut self) {
2098        for i in 0..self.nrows {
2099            for j in (i + 1)..self.ncols {
2100                let lhs = at!(i, j, self.ncols);
2101                let rhs = at!(j, i, self.nrows);
2102                self.data.swap(lhs, rhs);
2103            }
2104        }
2105
2106        swap(&mut self.nrows, &mut self.ncols);
2107    }
2108
2109    /// Shorthand call for transpose
2110    ///
2111    /// # Examples
2112    ///
2113    /// ```
2114    /// use sukker::Matrix;
2115    ///
2116    /// let mut matrix = Matrix::init(2.0, (2,100));
2117    /// matrix.t();
2118    ///
2119    /// assert_eq!(matrix.shape(), (100,2));
2120    /// ```
2121    pub fn t(&mut self) {
2122        self.transpose()
2123    }
2124
2125    /// Transpose a matrix and return a copy
2126    ///
2127    /// # Examples
2128    ///
2129    /// ```
2130    /// use sukker::Matrix;
2131    ///
2132    /// let matrix = Matrix::init(2.0, (2,100));
2133    /// let result = matrix.transpose_copy();
2134    ///
2135    /// assert_eq!(result.shape(), (100,2));
2136    /// ```
2137    pub fn transpose_copy(&self) -> Self {
2138        let mut res = self.clone();
2139        res.transpose();
2140        res
2141    }
2142}
2143
2144/// Implementations for predicates
2145impl<'a, T> Matrix<'a, T>
2146where
2147    T: MatrixElement,
2148    <T as FromStr>::Err: Error + 'static,
2149    Vec<T>: IntoParallelIterator,
2150    Vec<&'a T>: IntoParallelRefIterator<'a>,
2151{
2152    /// Counts all occurances where predicate holds
2153    ///
2154    /// # Examples
2155    ///
2156    /// ```
2157    /// use sukker::Matrix;
2158    ///
2159    /// let matrix = Matrix::init(2.0f32, (2,4));
2160    ///
2161    /// assert_eq!(matrix.count_where(|&e| e == 2.0), 8);
2162    /// ```
2163    pub fn count_where<F>(&'a self, pred: F) -> usize
2164    where
2165        F: Fn(&T) -> bool + Sync,
2166    {
2167        self.data.par_iter().filter(|&e| pred(e)).count()
2168    }
2169
2170    /// Sums all occurances where predicate holds
2171    ///
2172    /// # Examples
2173    ///
2174    /// ```
2175    /// use sukker::Matrix;
2176    ///
2177    /// let matrix = Matrix::init(2.0, (2,4));
2178    ///
2179    /// assert_eq!(matrix.sum_where(|&e| e == 2.0), 16.0);
2180    /// ```
2181    pub fn sum_where<F>(&self, pred: F) -> T
2182    where
2183        F: Fn(&T) -> bool + Sync,
2184    {
2185        self.data
2186            .par_iter()
2187            .filter(|&e| pred(e))
2188            .copied()
2189            .sum::<T>()
2190    }
2191
2192    /// Setsall elements where predicate holds true.
2193    /// The new value is to be set inside the predicate as well
2194    ///
2195    /// # Examples
2196    ///
2197    /// ```
2198    /// use sukker::Matrix;
2199    ///
2200    /// let mut matrix = Matrix::init(2.0, (2,4));
2201    ///
2202    /// assert_eq!(matrix.get(0,0).unwrap(), 2.0);
2203    ///
2204    /// matrix.set_where(|e| {
2205    ///     if *e == 2.0 {
2206    ///         *e = 2.3;
2207    ///     }
2208    /// });
2209    ///
2210    /// assert_eq!(matrix.get(0,0).unwrap(), 2.3);
2211    /// ```
2212    pub fn set_where<P>(&mut self, mut pred: P)
2213    where
2214        P: FnMut(&mut T) + Sync + Send,
2215    {
2216        self.data.iter_mut().for_each(|e| pred(e));
2217    }
2218
2219    /// Return whether or not a predicate holds at least once
2220    ///
2221    /// # Examples
2222    ///
2223    /// ```
2224    /// use sukker::Matrix;
2225    ///
2226    /// let matrix = Matrix::init(2.0, (2,4));
2227    ///
2228    /// assert_eq!(matrix.any(|&e| e == 2.0), true);
2229    /// ```
2230    pub fn any<F>(&self, pred: F) -> bool
2231    where
2232        F: Fn(&T) -> bool + Sync + Send,
2233    {
2234        self.data.par_iter().any(pred)
2235    }
2236
2237    /// Returns whether or not predicate holds for all values
2238    ///
2239    /// # Examples
2240    ///
2241    /// ```
2242    /// use sukker::Matrix;
2243    ///
2244    /// let matrix = Matrix::randomize_range(1.0, 4.0, (2,4));
2245    ///
2246    /// assert_eq!(matrix.all(|&e| e >= 1.0), true);
2247    /// ```
2248    pub fn all<F>(&self, pred: F) -> bool
2249    where
2250        F: Fn(&T) -> bool + Sync + Send,
2251    {
2252        self.data.par_iter().all(pred)
2253    }
2254
2255    /// Finds first index where predicates holds if possible
2256    ///
2257    /// # Examples
2258    ///
2259    /// ```
2260    /// use sukker::Matrix;
2261    ///
2262    /// let matrix = Matrix::init(2f32, (2,4));
2263    ///
2264    /// assert_eq!(matrix.find(|&e| e >= 1f32), Some((0,0)));
2265    /// ```
2266    pub fn find<F>(&self, pred: F) -> Option<Shape>
2267    where
2268        F: Fn(&T) -> bool + Sync,
2269    {
2270        if let Some((idx, _)) = self.data.iter().find_position(|&e| pred(e)) {
2271            return Some(self.one_to_2d_idx(idx));
2272        }
2273
2274        None
2275    }
2276
2277    /// Finds all indeces where predicates holds if possible
2278    ///
2279    /// # Examples
2280    ///
2281    /// ```
2282    /// use sukker::Matrix;
2283    ///
2284    /// let matrix = Matrix::init(2.0, (2,4));
2285    ///
2286    /// assert_eq!(matrix.find_all(|&e| e >= 3.0), None);
2287    /// ```
2288    pub fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
2289    where
2290        F: Fn(&T) -> bool + Sync,
2291    {
2292        let data: Vec<Shape> = self
2293            .data
2294            .par_iter()
2295            .enumerate()
2296            .filter_map(|(idx, elem)| {
2297                if pred(elem) {
2298                    Some(self.one_to_2d_idx(idx))
2299                } else {
2300                    None
2301                }
2302            })
2303            .collect();
2304
2305        if data.is_empty() {
2306            None
2307        } else {
2308            Some(data)
2309        }
2310    }
2311}