linalg_rs/matrix/mod.rs
//! Making work with matrices in Rust easier!
//!
//! For now, only basic operations are supported, but more are to be added.
//!
//! This file is kept under 1500 lines and acts as the core file of the matrix module.
6
7mod helper;
8mod optim;
9
10use helper::*;
11
12use serde::{Deserialize, Serialize};
13use std::{
14 error::Error,
15 fmt::{Debug, Display},
16 fs,
17 marker::PhantomData,
18 ops::Div,
19 str::FromStr,
20};
21
22use itertools::{iproduct, Itertools};
23use num_traits::{pow, real::Real, sign::abs, Float};
24use rand::Rng;
25use rayon::prelude::*;
26use std::iter::Sum;
27
28use crate::{at, LinAlgFloats, MatrixElement, MatrixError, SparseMatrix};
29
/// Dimensions of a matrix, written as `(rows, columns)`.
///
/// Used both for a matrix's own shape and for shape arguments
/// such as the size of a sub-matrix.
pub type Shape = (usize, usize);
33
34/// Helper method to swap to usizes
35
36#[derive(Clone, PartialEq, PartialOrd, Debug, Serialize, Deserialize)]
37/// General dense matrix
38pub struct Matrix<'a, T>
39where
40 T: MatrixElement,
41 <T as FromStr>::Err: Error + 'static,
42 Vec<T>: IntoParallelIterator,
43 Vec<&'a T>: IntoParallelRefIterator<'a>,
44{
45 /// Vector containing all data
46 data: Vec<T>,
47 /// Number of rows
48 pub nrows: usize,
49 /// Number of columns
50 pub ncols: usize,
51 _lifetime: PhantomData<&'a T>,
52}
53
54impl<'a, T> Error for Matrix<'a, T>
55where
56 T: MatrixElement,
57 <T as FromStr>::Err: Error + 'static,
58 Vec<T>: IntoParallelIterator,
59 Vec<&'a T>: IntoParallelRefIterator<'a>,
60{
61}
62
63unsafe impl<'a, T> Send for Matrix<'a, T>
64where
65 T: MatrixElement,
66 <T as FromStr>::Err: Error + 'static,
67 Vec<T>: IntoParallelIterator,
68 Vec<&'a T>: IntoParallelRefIterator<'a>,
69{
70}
71
72unsafe impl<'a, T> Sync for Matrix<'a, T>
73where
74 T: MatrixElement,
75 <T as FromStr>::Err: Error + 'static,
76 Vec<T>: IntoParallelIterator,
77 Vec<&'a T>: IntoParallelRefIterator<'a>,
78{
79}
80
81impl<'a, T> FromStr for Matrix<'a, T>
82where
83 T: MatrixElement,
84 <T as FromStr>::Err: Error + 'static,
85 Vec<T>: IntoParallelIterator,
86 Vec<&'a T>: IntoParallelRefIterator<'a>,
87{
88 type Err = anyhow::Error;
89 fn from_str(s: &str) -> Result<Self, Self::Err> {
90 // Parse the input string and construct the matrix dynamically
91 let v: Vec<T> = s
92 .trim()
93 .lines()
94 .map(|l| {
95 l.split_whitespace()
96 .map(|num| num.parse::<T>().unwrap())
97 .collect::<Vec<T>>()
98 })
99 .collect::<Vec<Vec<T>>>()
100 .into_iter()
101 .flatten()
102 .collect();
103
104 let rows = s.trim().lines().count();
105 let cols = s.trim().lines().nth(0).unwrap().split_whitespace().count();
106
107 Ok(Self::new(v, (rows, cols)).unwrap())
108 }
109}
110
111impl<'a, T> Display for Matrix<'a, T>
112where
113 T: MatrixElement,
114 <T as FromStr>::Err: Error + 'static,
115 Vec<T>: IntoParallelIterator,
116 Vec<&'a T>: IntoParallelRefIterator<'a>,
117{
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(f, "[");
120
121 // Large matrices
122 if self.nrows > 10 || self.ncols > 10 {
123 write!(f, "...");
124 }
125
126 for i in 0..self.nrows {
127 for j in 0..self.ncols {
128 if i == 0 {
129 write!(f, "{:.4} ", self.get(i, j).unwrap());
130 } else {
131 write!(f, " {:.4}", self.get(i, j).unwrap());
132 }
133 }
134 // Print ] on same line if youre at the end
135 if i == self.nrows - 1 {
136 break;
137 }
138 writeln!(f);
139 }
140 writeln!(f, "], dtype={}", std::any::type_name::<T>())
141 }
142}
143
144impl<'a, T> Default for Matrix<'a, T>
145where
146 T: MatrixElement,
147 <T as FromStr>::Err: Error + 'static,
148 Vec<T>: IntoParallelIterator,
149 Vec<&'a T>: IntoParallelRefIterator<'a>,
150{
151 /// Represents a default identity matrix
152 ///
153 /// # Examples
154 ///
155 /// ```
156 /// use sukker::Matrix;
157 ///
158 /// let matrix: Matrix<f32> = Matrix::default();
159 ///
160 /// assert_eq!(matrix.size(), 9);
161 /// assert_eq!(matrix.shape(), (3,3));
162 /// ```
163 fn default() -> Self {
164 Self::eye(3)
165 }
166}
167
168/// Printer functions for the matrix
169impl<'a, T> Matrix<'a, T>
170where
171 T: MatrixElement,
172 <T as FromStr>::Err: Error + 'static,
173 Vec<T>: IntoParallelIterator,
174 Vec<&'a T>: IntoParallelRefIterator<'a>,
175{
176 /// Prints out the matrix with however many decimals you want
177 ///
178 /// # Examples
179 ///
180 /// ```
181 /// use sukker::Matrix;
182 ///
183 /// let matrix: Matrix<i32> = Matrix::eye(2);
184 /// matrix.print(4);
185 ///
186 /// ```
187 pub fn print(&self, decimals: usize) {
188 print!("[");
189
190 // Large matrices
191 if self.nrows > 10 || self.ncols > 10 {
192 print!("...");
193 }
194
195 for i in 0..self.nrows {
196 for j in 0..self.ncols {
197 if i == 0 {
198 print!(
199 "{val:.dec$} ",
200 dec = decimals,
201 val = self.get(i, j).unwrap()
202 );
203 } else {
204 print!(
205 " {val:.dec$}",
206 dec = decimals,
207 val = self.get(i, j).unwrap()
208 );
209 }
210 }
211 // Print ] on same line if youre at the end
212 if i == self.nrows - 1 {
213 break;
214 }
215 println!();
216 }
217 println!("], dtype={}", std::any::type_name::<T>());
218 }
219
220 /// Calculates sparsity of a given Matrix
221 ///
222 /// Examples:
223 ///
224 /// ```
225 /// use sukker::Matrix;
226 ///
227 /// let mat: Matrix<f32> = Matrix::eye(2);
228 ///
229 /// assert_eq!(mat.sparsity(), 0.5);
230 /// ```
231 #[inline(always)]
232 pub fn sparsity(&'a self) -> f64 {
233 self.count_where(|&e| e == T::zero()) as f64 / self.size() as f64
234 }
235
236 /// Returns the shape of a matrix represented as
237 /// (usize, usize)
238 ///
239 /// Examples:
240 ///
241 /// ```
242 /// use sukker::Matrix;
243 ///
244 /// let mat: Matrix<f32> = Matrix::eye(4);
245 ///
246 /// assert_eq!(mat.shape(), (4,4));
247 /// ```
248 pub fn shape(&self) -> Shape {
249 (self.nrows, self.ncols)
250 }
251}
252
253/// Implementations of all creatins of matrices
254impl<'a, T> Matrix<'a, T>
255where
256 T: MatrixElement,
257 <T as FromStr>::Err: Error,
258 Vec<T>: IntoParallelIterator,
259 Vec<&'a T>: IntoParallelRefIterator<'a>,
260{
261 /// Creates a new matrix from a vector and the shape you want.
262 /// Will default init if it does not work
263 ///
264 /// # Examples
265 ///
266 /// ```
267 /// use sukker::Matrix;
268 ///
269 /// let matrix = Matrix::new(vec![1.0,2.0,3.0,4.0], (2,2)).unwrap();
270 ///
271 /// assert_eq!(matrix.size(), 4);
272 /// assert_eq!(matrix.shape(), (2,2));
273 /// ```
274 pub fn new(data: Vec<T>, shape: Shape) -> Result<Self, MatrixError> {
275 if shape.0 * shape.1 != data.len() {
276 return Err(MatrixError::MatrixCreationError.into());
277 }
278
279 Ok(Self {
280 data,
281 nrows: shape.0,
282 ncols: shape.1,
283 _lifetime: PhantomData::default(),
284 })
285 }
286
287 /// Initializes a matrix with the same value
288 /// given from parameter 'value'
289 ///
290 /// # Examples
291 ///
292 /// ```
293 /// use sukker::Matrix;
294 ///
295 /// let matrix = Matrix::init(4f32, (1,2));
296 ///
297 /// assert_eq!(matrix.get_vec(), vec![4f32,4f32]);
298 /// assert_eq!(matrix.shape(), (1,2));
299 /// ```
300 pub fn init(value: T, shape: Shape) -> Self {
301 Self::from_shape(value, shape)
302 }
303
304 /// Returns an eye matrix which for now is the same as the
305 /// identity matrix
306 ///
307 /// # Examples
308 ///
309 /// ```
310 /// use sukker::Matrix;
311 ///
312 /// let matrix: Matrix<f32> = Matrix::eye(2);
313 ///
314 /// assert_eq!(matrix.get_vec(), vec![1f32, 0f32, 0f32, 1f32]);
315 /// assert_eq!(matrix.shape(), (2,2));
316 /// ```
317 pub fn eye(size: usize) -> Self {
318 let mut data: Vec<T> = vec![T::zero(); size * size];
319
320 (0..size).for_each(|i| data[i * size + i] = T::one());
321
322 // Safe to do since the library is setting the size
323 Self::new(data, (size, size)).unwrap()
324 }
325
326 /// Produce an identity matrix in the same shape as
327 /// an already existent matrix
328 ///
329 /// Examples
330 ///
331 /// ```
332 ///
333 ///
334 /// ```
335 pub fn eye_like(matrix: &Self) -> Self {
336 Self::eye(matrix.nrows)
337 }
338
339 /// Identity is same as eye, just for nerds
340 ///
341 /// # Examples
342 ///
343 /// ```
344 /// use sukker::Matrix;
345 ///
346 /// let matrix: Matrix<i64> = Matrix::identity(2);
347 ///
348 /// assert_eq!(matrix.get_vec(), vec![1i64, 0i64, 0i64, 1i64]);
349 /// assert_eq!(matrix.shape(), (2,2));
350 /// ```
351 pub fn identity(size: usize) -> Self {
352 Self::eye(size)
353 }
354
355 /// Tries to create a matrix from a slize and shape
356 ///
357 /// # Examples
358 ///
359 /// ```
360 /// use sukker::Matrix;
361 ///
362 /// let s = vec![1f32, 2f32, 3f32, 4f32];
363 /// let matrix = Matrix::from_slice(&s, (4,1)).unwrap();
364 ///
365 /// assert_eq!(matrix.shape(), (4,1));
366 /// ```
367 pub fn from_slice(arr: &[T], shape: Shape) -> Result<Self, MatrixError> {
368 if shape.0 * shape.1 != arr.len() {
369 return Err(MatrixError::MatrixCreationError.into());
370 }
371
372 Ok(Self::new(arr.to_owned(), shape).unwrap())
373 }
374
375 /// Creates a matrix where all values are 0.
376 /// All sizes are based on a shape
377 ///
378 /// # Examples
379 ///
380 /// ```
381 /// use sukker::Matrix;
382 ///
383 /// let matrix: Matrix<f64> = Matrix::zeros((4,1));
384 ///
385 /// assert_eq!(matrix.shape(), (4,1));
386 /// assert_eq!(matrix.get(0,0).unwrap(), 0f64);
387 /// assert_eq!(matrix.size(), 4);
388 /// ```
389 pub fn zeros(shape: Shape) -> Self {
390 Self::from_shape(T::zero(), shape)
391 }
392
393 /// Creates a matrix where all values are 1.
394 /// All sizes are based on a shape
395 ///
396 /// # Examples
397 ///
398 /// ```
399 /// use sukker::Matrix;
400 ///
401 /// let matrix: Matrix<f64> = Matrix::ones((4,1));
402 ///
403 /// assert_eq!(matrix.shape(), (4,1));
404 /// assert_eq!(matrix.get(0,0).unwrap(), 1f64);
405 /// assert_eq!(matrix.size(), 4);
406 /// ```
407 pub fn ones(shape: Shape) -> Self {
408 Self::from_shape(T::one(), shape)
409 }
410
411 /// Creates a matrix where all values are 0.
412 /// All sizes are based on an already existent matrix
413 ///
414 /// # Examples
415 ///
416 /// ```
417 /// use sukker::Matrix;
418 ///
419 /// let matrix1: Matrix<i8> = Matrix::default();
420 /// let matrix2 = Matrix::zeros_like(&matrix1);
421 ///
422 /// assert_eq!(matrix2.shape(), matrix1.shape());
423 /// assert_eq!(matrix2.get(0,0).unwrap(), 0i8);
424 /// ```
425 pub fn zeros_like(other: &Self) -> Self {
426 Self::from_shape(T::zero(), other.shape())
427 }
428
429 /// Creates a matrix where all values are 1.
430 /// All sizes are based on an already existent matrix
431 ///
432 /// # Examples
433 ///
434 /// ```
435 /// use sukker::Matrix;
436 ///
437 /// let matrix1: Matrix<i64> = Matrix::default();
438 /// let matrix2 = Matrix::ones_like(&matrix1);
439 ///
440 /// assert_eq!(matrix2.shape(), matrix1.shape());
441 /// assert_eq!(1i64, matrix2.get(0,0).unwrap());
442 /// ```
443 pub fn ones_like(other: &Self) -> Self {
444 Self::from_shape(T::one(), other.shape())
445 }
446
447 /// Creates a matrix where all values are random between 0 and 1.
448 /// All sizes are based on an already existent matrix
449 ///
450 /// # Examples
451 ///
452 /// ```
453 /// use sukker::Matrix;
454 ///
455 /// let matrix1: Matrix<f32> = Matrix::default();
456 /// let matrix2 = Matrix::random_like(&matrix1);
457 ///
458 /// assert_eq!(matrix1.shape(), matrix2.shape());
459 ///
460 ///
461 /// ```
462 pub fn random_like(matrix: &Self) -> Self {
463 Self::randomize_range(T::zero(), T::one(), matrix.shape())
464 }
465
466 /// Creates a matrix where all values are random between start..=end.
467 /// Shape in new array is given through parameter 'shape'
468 ///
469 /// # Examples
470 ///
471 /// ```
472 /// use sukker::Matrix;
473 ///
474 /// let matrix = Matrix::randomize_range(1f32, 2f32, (2,3));
475 /// let elem = matrix.get(1,1).unwrap();
476 ///
477 /// assert_eq!(matrix.shape(), (2,3));
478 /// //assert!(elem >= 1f32 && 2f32 <= elem);
479 /// ```
480 pub fn randomize_range(start: T, end: T, shape: Shape) -> Self {
481 let mut rng = rand::thread_rng();
482
483 let (rows, cols) = shape;
484
485 let len: usize = rows * cols;
486
487 let data: Vec<T> = (0..len).map(|_| rng.gen_range(start..=end)).collect();
488
489 // Safe because shape doesn't have to match data from a user
490 Self::new(data, shape).unwrap()
491 }
492
493 /// Creates a matrix where all values are random between 0..=1.
494 /// Shape in new array is given through parameter 'shape'
495 ///
496 /// # Examples
497 ///
498 /// ```
499 /// use sukker::Matrix;
500 ///
501 /// let matrix: Matrix<f64> = Matrix::randomize((2,3));
502 ///
503 /// assert_eq!(matrix.shape(), (2,3));
504 /// ```
505 pub fn randomize(shape: Shape) -> Self {
506 Self::randomize_range(T::zero(), T::one(), shape)
507 }
508
509 /// Parses from file, but will return a default matrix if nothing is given
510 ///
511 /// # Examples
512 ///
513 /// ```
514 /// use sukker::Matrix;
515 ///
516 /// // let m: Matrix<f32> = Matrix::from_file("../../test.txt").unwrap();
517 ///
518 /// // m.print(4);
519 /// ```
520 pub fn from_file(path: &'static str) -> Result<Self, MatrixError> {
521 let data =
522 fs::read_to_string(path).map_err(|_| MatrixError::MatrixFileReadError(path).into())?;
523
524 data.parse::<Self>()
525 .map_err(|_| MatrixError::MatrixParseError.into())
526 }
527
528 /// Constructs a new dense matrix from a sparse one.
529 ///
530 /// This transfesrs ownership as well!
531 ///
532 /// Examples
533 ///
534 /// ```
535 /// use sukker::{Matrix, SparseMatrix};
536 ///
537 /// let sparse = SparseMatrix::<i32>::eye(3);
538 ///
539 /// let matrix = Matrix::from_sparse(sparse);
540 ///
541 /// assert_eq!(matrix.shape(), (3,3));
542 /// assert_eq!(matrix.at(0,0), 1);
543 /// ```
544 pub fn from_sparse(sparse: SparseMatrix<'a, T>) -> Self {
545 let mut mat = Self::zeros(sparse.shape());
546
547 for (&idx, &val) in sparse.data.iter() {
548 mat.set(val, idx);
549 }
550
551 mat
552 }
553
554 /// Helper function to create matrices
555 fn from_shape(value: T, shape: Shape) -> Self {
556 let (rows, cols) = shape;
557
558 let len: usize = rows * cols;
559
560 let data = vec![value; len];
561
562 Self::new(data, shape).unwrap()
563 }
564}
565
/// Selects which axis of a matrix an operation works along.
///
/// The derives let a `Dimension` be copied, compared and debug-printed. Callers
/// can therefore reuse one value across several calls.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dimension {
    /// Operate along a row; discriminant 0.
    Row = 0,
    /// Operate along a column; discriminant 1.
    Col = 1,
}
573
574/// Regular matrix methods that are not operating math on them
575impl<'a, T> Matrix<'a, T>
576where
577 T: MatrixElement + Div<Output = T> + Sum<T>,
578 <T as FromStr>::Err: Error + 'static,
579 Vec<T>: IntoParallelIterator,
580 Vec<&'a T>: IntoParallelRefIterator<'a>,
581{
582 /// Reshapes a matrix if possible.
583 /// If the shapes don't match up, the old shape will be retained
584 ///
585 /// # Examples
586 ///
587 /// ```
588 /// use sukker::Matrix;
589 ///
590 /// let mut matrix = Matrix::init(10.5, (2,3));
591 /// matrix.reshape(3,2);
592 ///
593 /// assert_eq!(matrix.shape(), (3,2));
594 /// ```
595 pub fn reshape(&mut self, nrows: usize, ncols: usize) {
596 if nrows * ncols != self.size() {
597 eprintln!("Err: Can not reshape.. Keeping old dimensions for now");
598 return;
599 }
600
601 self.nrows = nrows;
602 self.ncols = ncols;
603 }
604
605 /// Get the total size of the matrix
606 ///
607 /// # Examples
608 ///
609 /// ```
610 /// use sukker::Matrix;
611 ///
612 /// let matrix = Matrix::init(10.5, (2,3));
613 ///
614 /// assert_eq!(matrix.size(), 6);
615 /// ```
616 pub fn size(&self) -> usize {
617 self.nrows * self.ncols
618 }
619
620 /// Gets element based on is and js
621 ///
622 /// # Examples
623 ///
624 /// ```
625 /// use sukker::Matrix;
626 ///
627 /// let matrix = Matrix::init(10.5f32, (2,3));
628 ///
629 /// assert_eq!(matrix.get(0,1).unwrap(), 10.5f32);
630 /// ```
631 pub fn get(&self, i: usize, j: usize) -> Option<T> {
632 let idx = at!(i, j, self.ncols);
633
634 if idx >= self.size() {
635 return None;
636 }
637
638 Some(self.at(i, j))
639 }
640
641 /// Gets element based on is and js, but will
642 /// panic if indexes are out of range.
643 ///
644 /// # Examples
645 ///
646 /// ```
647 /// use sukker::Matrix;
648 ///
649 /// let val = 10.5;
650 ///
651 /// let matrix = Matrix::init(val, (2,3));
652 ///
653 /// assert_eq!(matrix.at(1,2), val);
654 /// ```
655 #[inline(always)]
656 pub fn at(&self, i: usize, j: usize) -> T {
657 self.data[at!(i, j, self.ncols)]
658 }
659
660 /// Gets a piece of the matrix out as a vector
661 ///
662 /// If some indeces are out of bounds, the vec up until that point
663 /// will be returned
664 ///
665 /// # Examples
666 ///
667 /// ```
668 /// use sukker::Matrix;
669 ///
670 /// let matrix = Matrix::init(10.5, (4,4));
671 /// let slice = matrix.get_vec_slice((1,1), (2,2));
672 ///
673 /// assert_eq!(slice, vec![10.5,10.5,10.5,10.5]);
674 /// ```
675 pub fn get_vec_slice(&self, start_idx: Shape, size: Shape) -> Vec<T> {
676 let (start_row, start_col) = start_idx;
677 let (dx, dy) = size;
678
679 iproduct!(start_row..start_row + dy, start_col..start_col + dx)
680 .filter_map(|(i, j)| self.get(i, j))
681 .collect()
682 }
683
684 /// Gets you the whole entire matrix as a vector
685 ///
686 /// # Examples
687 ///
688 /// ```
689 /// use sukker::Matrix;
690 ///
691 /// let matrix = Matrix::init(10.5, (4,4));
692 /// let slice = matrix.get_vec_slice((1,1), (2,2));
693 ///
694 /// assert_eq!(slice, vec![10.5,10.5,10.5,10.5]);
695 /// ```
696 pub fn get_vec(&self) -> Vec<T> {
697 self.data.clone()
698 }
699
700 /// Gets a piece of the matrix out as a matrix
701 ///
702 /// If some indeces are out of bounds, unlike `get_vec_slice`
703 /// this function will return an IndexOutOfBoundsError
704 /// and will not return data
705 ///
706 /// # Examples
707 ///
708 /// ```
709 /// use sukker::Matrix;
710 ///
711 /// let matrix = Matrix::init(10.5, (4,4));
712 /// let sub_matrix = matrix.get_sub_matrix((1,1), (2,2)).unwrap();
713 ///
714 /// assert_eq!(sub_matrix.get_vec(), vec![10.5,10.5,10.5,10.5]);
715 /// ```
716 pub fn get_sub_matrix(&self, start_idx: Shape, size: Shape) -> Result<Self, MatrixError> {
717 let (start_row, start_col) = start_idx;
718 let (dx, dy) = size;
719
720 let data = iproduct!(start_row..start_row + dy, start_col..start_col + dx)
721 .filter_map(|(i, j)| self.get(i, j))
722 .collect();
723
724 return match Self::new(data, size) {
725 Ok(a) => Ok(a),
726 Err(_) => Err(MatrixError::MatrixIndexOutOfBoundsError.into()),
727 };
728 }
729
730 /// Concat two mtrices on a dimension
731 ///
732 /// # Examples
733 ///
734 /// ```
735 /// use sukker::Matrix;
736 /// use sukker::Dimension;
737 ///
738 /// let matrix = Matrix::init(10.5, (4,4));
739 /// let matrix2 = Matrix::init(10.5, (1,4));
740 ///
741 /// let res = matrix.concat(&matrix2, Dimension::Row).unwrap();
742 ///
743 /// assert_eq!(res.shape(), (5,4));
744 /// ```
745 pub fn concat(&self, other: &Self, dim: Dimension) -> Result<Self, MatrixError> {
746 match dim {
747 Dimension::Row => {
748 if self.ncols != other.ncols {
749 return Err(MatrixError::MatrixConcatinationError.into());
750 }
751
752 let mut new_data = self.data.clone();
753
754 new_data.extend(other.data.iter());
755
756 let nrows = self.nrows + other.nrows;
757 let shape = (nrows, self.ncols);
758
759 return Ok(Self::new(new_data, shape).unwrap());
760 }
761
762 Dimension::Col => {
763 if self.nrows != other.nrows {
764 return Err(MatrixError::MatrixConcatinationError.into());
765 }
766
767 let mut new_data: Vec<T> = Vec::new();
768
769 let take_self = self.ncols;
770 let take_other = other.ncols;
771
772 for (idx, _) in self.data.iter().step_by(take_self).enumerate() {
773 // Add from self, then other
774 let row = (idx / take_self) * take_self;
775 new_data.extend(self.data.iter().skip(row).take(take_self));
776 new_data.extend(other.data.iter().skip(row).take(take_other));
777 }
778
779 let ncols = self.ncols + other.ncols;
780 let shape = (self.nrows, ncols);
781
782 return Ok(Self::new(new_data, shape).unwrap());
783 }
784 };
785 }
786
787 // TODO: Add option to transpose to be able to extend
788 // Doens't change anything if dimension mismatch
789
790 /// Extend a matrix with another on a dimension
791 ///
792 /// # Examples
793 ///
794 /// ```
795 /// use sukker::Matrix;
796 /// use sukker::Dimension;
797 ///
798 /// let mut matrix = Matrix::init(10.5, (4,4));
799 /// let matrix2 = Matrix::init(10.5, (4,1));
800 ///
801 /// matrix.extend(&matrix2, Dimension::Col);
802 ///
803 /// assert_eq!(matrix.shape(), (4,5));
804 /// ```
805 pub fn extend(&mut self, other: &Self, dim: Dimension) {
806 match dim {
807 Dimension::Row => {
808 if self.ncols != other.ncols {
809 eprintln!("Error: Dimension mismatch");
810 return;
811 }
812
813 self.data.extend(other.data.iter());
814
815 self.nrows += other.nrows;
816 }
817
818 Dimension::Col => {
819 if self.nrows != other.nrows {
820 eprintln!("Error: Dimension mismatch");
821 return;
822 }
823
824 let mut new_data: Vec<T> = Vec::new();
825
826 let take_self = self.ncols;
827 let take_other = other.ncols;
828
829 for (idx, _) in self.data.iter().step_by(take_self).enumerate() {
830 // Add from self, then other
831 let row = (idx / take_self) * take_self;
832 new_data.extend(self.data.iter().skip(row).take(take_self));
833 new_data.extend(other.data.iter().skip(row).take(take_other));
834 }
835
836 self.ncols += other.ncols;
837 }
838 };
839 }
840
841 /// Sets element based on is and js
842 ///
843 /// Sets nothing if you;re out of bounds
844 ///
845 /// # Examples
846 ///
847 /// ```
848 /// use sukker::Matrix;
849 ///
850 /// let mut matrix = Matrix::init(10.5, (2,3));
851 /// matrix.set(11.5, (1, 2));
852 ///
853 /// assert_eq!(matrix.get(1,2).unwrap(), 11.5);
854 /// ```
855 pub fn set(&mut self, value: T, idx: Shape) {
856 let idx = at!(idx.0, idx.1, self.ncols);
857
858 if idx >= self.size() {
859 eprintln!("Error: Index out of bounds. Not setting value.");
860 return;
861 }
862
863 self.data[idx] = value;
864 }
865
866 /// Sets many elements based on vector of indeces
867 ///
868 /// For indexes out of bounds, nothing is set
869 ///
870 /// # Examples
871 ///
872 /// ```
873 /// use sukker::Matrix;
874 ///
875 /// let mut matrix = Matrix::init(10.5, (2,3));
876 /// matrix.set_many(vec![(1,2), (1,1)], 11.5);
877 ///
878 /// assert_eq!(matrix.get(1,2).unwrap(), 11.5);
879 /// assert_eq!(matrix.get(1,1).unwrap(), 11.5);
880 /// assert_eq!(matrix.get(0,1).unwrap(), 10.5);
881 /// ```
882 pub fn set_many(&mut self, idx_list: Vec<Shape>, value: T) {
883 idx_list.iter().for_each(|&idx| self.set(value, idx));
884 }
885
886 /// Sets all elements of a matrix in a 1d range.
887 ///
888 /// The range is inclusive to stop, and will panic
889 /// if any indexes are out of range
890 ///
891 /// # Examples
892 ///
893 /// ```
894 /// use sukker::Matrix;
895 ///
896 /// let mut matrix = Matrix::init(10.5, (2,3));
897 /// matrix.set_range(0, 3, 11.5);
898 ///
899 /// assert_eq!(matrix.get(0,2).unwrap(), 11.5);
900 /// assert_eq!(matrix.get(0,1).unwrap(), 11.5);
901 /// assert_eq!(matrix.get(1,1).unwrap(), 10.5);
902 /// ```
903 pub fn set_range(&mut self, start: usize, stop: usize, value: T) {
904 (start..=stop).for_each(|i| self.data[i] = value);
905 }
906
907 /// Calculates the (row, col) for a matrix by a single index
908 ///
909 /// # Examples
910 ///
911 /// ```
912 /// use sukker::Matrix;
913 ///
914 /// let matrix = Matrix::init(10.5, (2,2));
915 /// let inv = matrix.one_to_2d_idx(1);
916 ///
917 /// assert_eq!(inv, (0,1));
918 /// ```
919 pub fn one_to_2d_idx(&self, idx: usize) -> Shape {
920 let row = idx / self.ncols;
921 let col = idx % self.ncols;
922
923 (row, col)
924 }
925
926 /// Finds maximum element in the matrix
927 ///
928 /// # Examples
929 ///
930 /// ```
931 /// use sukker::Matrix;
932 ///
933 /// let matrix = Matrix::init(10.5, (2,3));
934 ///
935 /// assert_eq!(matrix.max(), 10.5);
936 /// ```
937 pub fn max(&self) -> T {
938 // Matrix must have at least one element, thus we can unwrap
939 *self
940 .data
941 .par_iter()
942 .max_by(|a, b| a.partial_cmp(b).unwrap())
943 .unwrap()
944 }
945
946 /// Finds minimum element in the matrix
947 ///
948 /// # Examples
949 ///
950 /// ```
951 /// use sukker::Matrix;
952 ///
953 /// let mut matrix = Matrix::init(10.5, (2,3));
954 /// matrix.set(1.0, (0,2));
955 ///
956 /// assert_eq!(matrix.min(), 1.0);
957 /// ```
958 pub fn min(&self) -> T {
959 // Matrix must have at least one element, thus we can unwrap
960 *self
961 .data
962 .par_iter()
963 .min_by(|a, b| a.partial_cmp(b).unwrap())
964 .unwrap()
965 }
966
967 /// Finds position in matrix where value is highest.
968 /// Restricted to find this across a row or column
969 /// in the matrix.
970 ///
971 /// # Examples
972 ///
973 /// ```
974 /// use sukker::{Matrix, Dimension};
975 ///
976 /// let mut matrix = Matrix::init(1.0, (3,3));
977 /// matrix.set(15.0, (0,2));
978 ///
979 /// ```
980 fn argmax(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
981 match dimension {
982 Dimension::Row => {
983 if rowcol >= self.nrows - 1 {
984 return None;
985 }
986
987 let mut highest: T = T::one();
988 let mut i = 0;
989
990 for (idx, elem) in self
991 .data
992 .iter()
993 .enumerate()
994 .skip(rowcol * self.ncols)
995 .take(self.ncols)
996 {
997 if *elem >= highest {
998 i = idx;
999 }
1000 }
1001
1002 Some(self.one_to_2d_idx(i))
1003 }
1004
1005 Dimension::Col => {
1006 if rowcol >= self.ncols - 1 {
1007 return None;
1008 }
1009
1010 let mut highest: T = T::one();
1011
1012 let mut i = 0;
1013
1014 for (idx, elem) in self
1015 .data
1016 .iter()
1017 .enumerate()
1018 .skip(rowcol)
1019 .step_by(self.ncols)
1020 {
1021 if *elem >= highest {
1022 i = idx;
1023 }
1024 }
1025
1026 Some(self.one_to_2d_idx(i))
1027 }
1028 }
1029 }
1030
1031 /// Finds position in matrix where value is lowest.
1032 /// Restricted to find this across a row or column
1033 /// in the matrix.
1034 ///
1035 ///
1036 /// # Examples
1037 ///
1038 /// ```
1039 /// use sukker::{Matrix, Dimension};
1040 ///
1041 /// let mut matrix = Matrix::init(10.5, (3,3));
1042 /// matrix.set(1.0, (0,1));
1043 ///
1044 /// // assert_eq!(matrix.argmin(1, Dimension::Col), Some(1));
1045 /// ```
1046 fn argmin(&self, rowcol: usize, dimension: Dimension) -> Option<Shape> {
1047 match dimension {
1048 Dimension::Row => {
1049 if rowcol >= self.nrows - 1 {
1050 return None;
1051 }
1052
1053 let mut lowest: T = T::zero();
1054
1055 let mut i = 0;
1056
1057 for (idx, elem) in self
1058 .data
1059 .iter()
1060 .enumerate()
1061 .skip(rowcol * self.ncols)
1062 .take(self.ncols)
1063 {
1064 if *elem < lowest {
1065 i = idx;
1066 }
1067 }
1068
1069 Some(self.one_to_2d_idx(i))
1070 }
1071
1072 Dimension::Col => {
1073 if rowcol >= self.ncols - 1 {
1074 return None;
1075 }
1076
1077 let mut lowest: T = T::zero();
1078
1079 let mut i = 0;
1080
1081 for (idx, elem) in self
1082 .data
1083 .iter()
1084 .enumerate()
1085 .skip(rowcol)
1086 .step_by(self.ncols)
1087 {
1088 if *elem <= lowest {
1089 i = idx;
1090 }
1091 }
1092
1093 Some(self.one_to_2d_idx(i))
1094 }
1095 }
1096 }
1097
1098 /// Finds total sum of matrix
1099 ///
1100 /// # Examples
1101 ///
1102 /// ```
1103 /// use sukker::Matrix;
1104 ///
1105 /// let matrix = Matrix::init(10f32, (2,2));
1106 ///
1107 /// assert_eq!(matrix.cumsum(), 40.0);
1108 /// ```
1109 pub fn cumsum(&self) -> T {
1110 if self.size() == 0 {
1111 return T::zero();
1112 }
1113
1114 self.data.par_iter().copied().sum()
1115 }
1116
1117 /// Multiplies all elements in matrix
1118 ///
1119 /// # Examples
1120 ///
1121 /// ```
1122 /// use sukker::Matrix;
1123 ///
1124 /// let matrix = Matrix::init(10f32, (2,2));
1125 ///
1126 /// assert_eq!(matrix.cumprod(), 10000.0);
1127 /// ```
1128 pub fn cumprod(&self) -> T {
1129 if self.size() == 0 {
1130 return T::zero();
1131 }
1132
1133 self.data.par_iter().copied().product()
1134 }
1135
1136 /// Gets the average of the matrix
1137 ///
1138 /// # Examples
1139 ///
1140 /// ```
1141 /// use sukker::Matrix;
1142 ///
1143 /// let matrix = Matrix::init(10f32, (2,2));
1144 ///
1145 /// assert_eq!(matrix.avg(), 10.0);
1146 /// ```
1147 pub fn avg(&self) -> T {
1148 self.data.par_iter().copied().sum::<T>() / self.size().to_string().parse::<T>().unwrap()
1149 }
1150
1151 /// Gets the mean of the matrix
1152 ///
1153 /// # Examples
1154 ///
1155 /// ```
1156 /// use sukker::Matrix;
1157 ///
1158 /// let matrix = Matrix::init(10f32, (2,2));
1159 ///
1160 /// assert_eq!(matrix.mean(), 10.0);
1161 /// ```
1162 pub fn mean(&self) -> T {
1163 self.avg()
1164 }
1165
1166 /// Gets the median of the matrix
1167 ///
1168 /// # Examples
1169 ///
1170 /// ```
1171 /// use sukker::Matrix;
1172 ///
1173 /// let matrix = Matrix::new(vec![1.0, 4.0, 6.0, 5.0], (2,2)).unwrap();
1174 ///
1175 /// assert!(matrix.median() >= 4.45 && matrix.median() <= 4.55);
1176 /// ```
1177 pub fn median(&self) -> T {
1178 if self.size() == 1 {
1179 return self.at(0, 0);
1180 }
1181
1182 match self.data.len() % 2 {
1183 0 => {
1184 let half: usize = self.data.len() / 2;
1185
1186 self.data
1187 .iter()
1188 .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
1189 .skip(half - 1)
1190 .take(2)
1191 .copied()
1192 .sum::<T>()
1193 / (T::one() + T::one())
1194 }
1195 1 => {
1196 let half: usize = self.data.len() / 2;
1197
1198 self.data
1199 .iter()
1200 .sorted_by(|a, b| a.partial_cmp(&b).unwrap())
1201 .nth(half)
1202 .copied()
1203 .unwrap()
1204 }
1205 _ => unreachable!(),
1206 }
1207 }
1208
1209 /// Sums up elements over given axis and dimension.
1210 /// Will return 0 if you're out of bounds
1211 ///
1212 /// sum(2, Dimension::Col) means summig up these ones
1213 ///
1214 /// [ 10 10 (10) 10 10
1215 /// 10 10 (10) 10 10
1216 /// 10 10 (10) 10 10
1217 /// 10 10 (10) 10 10
1218 /// 10 10 (10) 10 10 ]
1219 ///
1220 /// = 10 * 5 = 50
1221 ///
1222 /// # Examples
1223 ///
1224 /// ```
1225 /// use sukker::Matrix;
1226 /// use sukker::Dimension;
1227 ///
1228 /// let matrix = Matrix::init(10f32, (5,5));
1229 ///
1230 /// assert_eq!(matrix.sum(0, Dimension::Row), 50.0);
1231 /// assert_eq!(matrix.sum(3, Dimension::Col), 50.0);
1232 /// ```
1233 pub fn sum(&self, rowcol: usize, dimension: Dimension) -> T {
1234 // TODO: Add out of bounds options
1235 if self.size() == 1 {
1236 return self.at(0, 0);
1237 }
1238
1239 match dimension {
1240 Dimension::Row => self
1241 .data
1242 .par_iter()
1243 .skip(rowcol * self.ncols)
1244 .take(self.ncols)
1245 .copied()
1246 .sum(),
1247 Dimension::Col => self
1248 .data
1249 .par_iter()
1250 .skip(rowcol)
1251 .step_by(self.ncols)
1252 .copied()
1253 .sum(),
1254 }
1255 }
1256
1257 /// Prods up elements over given rowcol and dimension
1258 /// Will return 1 if you're out of bounds.
1259 ///
1260 /// See `sum` for example on how this is calculated
1261 ///
1262 /// # Examples
1263 ///
1264 /// ```
1265 /// use sukker::Matrix;
1266 /// use sukker::Dimension;
1267 ///
1268 /// let matrix = Matrix::init(10f32, (2,2));
1269 ///
1270 /// assert_eq!(matrix.prod(0, Dimension::Row), 100.0);
1271 /// assert_eq!(matrix.prod(0, Dimension::Col), 100.0);
1272 /// ```
1273 pub fn prod(&self, rowcol: usize, dimension: Dimension) -> T {
1274 match dimension {
1275 Dimension::Row => self
1276 .data
1277 .par_iter()
1278 .skip(rowcol * self.ncols)
1279 .take(self.ncols)
1280 .copied()
1281 .product(),
1282 Dimension::Col => self
1283 .data
1284 .par_iter()
1285 .skip(rowcol)
1286 .step_by(self.ncols)
1287 .copied()
1288 .product(),
1289 }
1290 }
1291}
1292
1293/// Linalg on floats
1294impl<'a, T> LinAlgFloats<'a, T> for Matrix<'a, T>
1295where
1296 T: MatrixElement + Float + 'a,
1297 <T as FromStr>::Err: Error + 'static,
1298 Vec<T>: IntoParallelIterator,
1299 Vec<&'a T>: IntoParallelRefIterator<'a>,
1300{
1301 /// Takes the logarithm of each element
1302 ///
1303 /// # Examples
1304 ///
1305 /// ```
1306 /// use sukker::{Matrix, LinAlgFloats};
1307 ///
1308 /// let matrix = Matrix::init(10.0, (2,2));
1309 /// let result = matrix.log(10.0);
1310 ///
1311 /// assert_eq!(result.all(|&e| e == 1.0), true);
1312 ///
1313 /// ```
1314 fn log(&self, base: T) -> Self {
1315 let data: Vec<T> = self.data.par_iter().map(|&e| e.log(base)).collect();
1316
1317 Self::new(data, self.shape()).unwrap()
1318 }
1319
1320 /// Takes the natural logarithm of each element in a matrix
1321 ///
1322 /// # Examples
1323 ///
1324 /// ```
1325 /// use sukker::{Matrix, LinAlgFloats};
1326 /// use sukker::constants::EF64;
1327 ///
1328 /// let matrix: Matrix<f64> = Matrix::init(EF64, (2,2));
1329 ///
1330 /// let res = matrix.ln();
1331 /// ```
1332 fn ln(&self) -> Self {
1333 let data: Vec<T> = self.data.par_iter().map(|&e| e.ln()).collect();
1334
1335 Self::new(data, self.shape()).unwrap()
1336 }
1337
1338 /// Takes the square root of each element in a matrix.
1339 /// If some elements are negative, these will be kept the same
1340 ///
1341 /// # Examples
1342 ///
1343 /// ```
1344 /// use sukker::{Matrix, LinAlgFloats};
1345 ///
1346 /// let matrix: Matrix<f64> = Matrix::init(9.0, (3,3));
1347 ///
1348 /// let res = matrix.sqrt();
1349 ///
1350 /// assert_eq!(res.all(|&e| e == 3.0), true);
1351 /// ```
1352 fn sqrt(&self) -> Self {
1353 let data: Vec<T> = self
1354 .data
1355 .par_iter()
1356 .map(|&e| if e > T::zero() { e.sqrt() } else { e })
1357 .collect();
1358
1359 Self::new(data, self.shape()).unwrap()
1360 }
1361
1362 /// Gets sin of every value
1363 ///
1364 /// # Examples
1365 ///
1366 /// ```
1367 /// use sukker::{Matrix, LinAlgFloats};
1368 ///
1369 /// let matrix = Matrix::init(1.0, (2,2));
1370 ///
1371 /// let res = matrix.sin();
1372 /// ```
1373 fn sin(&self) -> Self {
1374 let data: Vec<T> = self.data.par_iter().map(|&e| e.sin()).collect();
1375
1376 Self::new(data, self.shape()).unwrap()
1377 }
1378
1379 /// Gets cos of every value
1380 ///
1381 /// # Examples
1382 ///
1383 /// ```
1384 /// use sukker::{Matrix, LinAlgFloats};
1385 /// use sukker::constants::EF32;
1386 ///
1387 /// let matrix = Matrix::init(EF32, (2,2));
1388 ///
1389 /// let res = matrix.cos();
1390 /// ```
1391 fn cos(&self) -> Self {
1392 let data: Vec<T> = self.data.par_iter().map(|&e| e.cos()).collect();
1393
1394 Self::new(data, self.shape()).unwrap()
1395 }
1396
1397 /// Gets tan of every value
1398 ///
1399 /// # Examples
1400 ///
1401 /// ```
1402 /// use sukker::{Matrix, LinAlgFloats};
1403 /// use sukker::constants::EF32;
1404 ///
1405 /// let matrix = Matrix::init(EF32, (2,2));
1406 ///
1407 /// let res = matrix.tan();
1408 /// ```
1409 fn tan(&self) -> Self {
1410 let data: Vec<T> = self.data.par_iter().map(|&e| e.tan()).collect();
1411
1412 Self::new(data, self.shape()).unwrap()
1413 }
1414
1415 /// Gets sinh of every value
1416 ///
1417 /// # Examples
1418 ///
1419 /// ```
1420 /// use sukker::{Matrix, LinAlgFloats};
1421 /// use sukker::constants::EF32;
1422 ///
1423 /// let matrix = Matrix::init(EF32, (2,2));
1424 ///
1425 /// let res = matrix.sinh();
1426 /// ```
1427 fn sinh(&self) -> Self {
1428 let data: Vec<T> = self.data.par_iter().map(|&e| e.sinh()).collect();
1429
1430 Self::new(data, self.shape()).unwrap()
1431 }
1432
1433 /// Gets cosh of every value
1434 ///
1435 /// # Examples
1436 ///
1437 /// ```
1438 /// use sukker::{Matrix, LinAlgFloats};
1439 /// use sukker::constants::EF32;
1440 ///
1441 /// let matrix = Matrix::init(EF32, (2,2));
1442 ///
1443 /// let res = matrix.cosh();
1444 /// ```
1445 fn cosh(&self) -> Self {
1446 let data: Vec<T> = self.data.par_iter().map(|&e| e.cosh()).collect();
1447
1448 Self::new(data, self.shape()).unwrap()
1449 }
1450
1451 /// Gets tanh of every value
1452 ///
1453 /// # Examples
1454 ///
1455 /// ```
1456 /// use sukker::{Matrix, LinAlgFloats};
1457 /// use sukker::constants::EF32;
1458 ///
1459 /// let matrix = Matrix::init(EF32, (2,2));
1460 ///
1461 /// let res = matrix.tanh();
1462 /// ```
1463 fn tanh(&self) -> Self {
1464 let data: Vec<T> = self.data.par_iter().map(|&e| e.tanh()).collect();
1465
1466 Self::new(data, self.shape()).unwrap()
1467 }
1468
1469 /// Find the eigenvale of a matrix
1470 ///
1471 /// # Examples
1472 ///
1473 /// ```
1474 /// use sukker::Matrix;
1475 ///
1476 /// let mut matrix = Matrix::init(2.0, (2,100));
1477 ///
1478 /// ```
1479 fn get_eigenvalues(&self) -> Option<Vec<T>> {
1480 todo!()
1481 }
1482
1483 /// Find the eigenvectors
1484 fn get_eigenvectors(&self) -> Option<Vec<T>> {
1485 unimplemented!()
1486 }
1487}
1488
1489/// trait MatrixLinAlg contains all common Linear Algebra functions to be
1490/// performed on matrices
1491impl<'a, T> Matrix<'a, T>
1492where
1493 T: MatrixElement,
1494 <T as FromStr>::Err: Error + 'static,
1495 Vec<T>: IntoParallelIterator,
1496 Vec<&'a T>: IntoParallelRefIterator<'a>,
1497{
1498 /// Adds one matrix to another
1499 ///
1500 /// # Examples
1501 ///
1502 /// ```
1503 /// use sukker::Matrix;
1504 ///
1505 /// let matrix1 = Matrix::init(10.0, (2,2));
1506 /// let matrix2 = Matrix::init(10.0, (2,2));
1507 ///
1508 /// assert_eq!(matrix1.add(&matrix2).unwrap().get(0,0).unwrap(), 20.0);
1509 /// ```
1510 pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
1511 if self.shape() != other.shape() {
1512 return Err(MatrixError::MatrixDimensionMismatchError.into());
1513 }
1514
1515 let data: Vec<T> = self
1516 .data
1517 .iter()
1518 .zip(other.data.iter())
1519 .map(|(&x, &y)| x + y)
1520 .collect();
1521
1522 Ok(Self::new(data, self.shape()).unwrap())
1523 }
1524
1525 /// Subtracts one matrix from another
1526 ///
1527 /// # Examples
1528 ///
1529 /// ```
1530 /// use sukker::Matrix;
1531 ///
1532 /// let matrix1 = Matrix::init(10.0, (2,2));
1533 /// let matrix2 = Matrix::init(10.0, (2,2));
1534 ///
1535 /// assert_eq!(matrix1.sub(&matrix2).unwrap().get(1,0).unwrap(), 0.0);
1536 /// ```
1537 pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
1538 if self.shape() != other.shape() {
1539 return Err(MatrixError::MatrixDimensionMismatchError.into());
1540 }
1541
1542 let data: Vec<T> = self
1543 .data
1544 .iter()
1545 .zip(other.data.iter())
1546 .map(|(&x, &y)| x - y)
1547 .collect();
1548
1549 Ok(Self::new(data, self.shape()).unwrap())
1550 }
1551
1552 /// Subtracts one array from another and returns the absolute value
1553 ///
1554 /// # Examples
1555 ///
1556 /// ```
1557 /// use sukker::Matrix;
1558 ///
1559 /// let matrix1 = Matrix::init(10.0f32, (2,2));
1560 /// let matrix2 = Matrix::init(15.0f32, (2,2));
1561 ///
1562 /// assert_eq!(matrix1.sub_abs(&matrix2).unwrap().get(0,0).unwrap(), 5.0);
1563 /// ```
1564 pub fn sub_abs(&self, other: &Self) -> Result<Self, MatrixError> {
1565 if self.shape() != other.shape() {
1566 return Err(MatrixError::MatrixDimensionMismatchError.into());
1567 }
1568
1569 let data = self
1570 .data
1571 .iter()
1572 .zip(other.data.iter())
1573 .map(|(&x, &y)| if x > y { x - y } else { y - x })
1574 .collect_vec();
1575
1576 Ok(Self::new(data, self.shape()).unwrap())
1577 }
1578
1579 /// Dot product of two matrices
1580 ///
1581 /// # Examples
1582 ///
1583 /// ```
1584 /// use sukker::Matrix;
1585 ///
1586 /// let matrix1 = Matrix::init(20.0, (2,2));
1587 /// let matrix2 = Matrix::init(10.0, (2,2));
1588 ///
1589 /// assert_eq!(matrix1.mul(&matrix2).unwrap().get(0,0).unwrap(), 200.0);
1590 /// ```
1591 pub fn mul(&self, other: &Self) -> Result<Self, MatrixError> {
1592 if self.shape() != other.shape() {
1593 return Err(MatrixError::MatrixDimensionMismatchError.into());
1594 }
1595
1596 let data = self
1597 .data
1598 .iter()
1599 .zip(other.data.iter())
1600 .map(|(&x, &y)| x * y)
1601 .collect_vec();
1602
1603 Ok(Self::new(data, self.shape()).unwrap())
1604 }
1605
1606 /// Dot product of two matrices
1607 ///
1608 /// # Examples
1609 ///
1610 /// ```
1611 /// use sukker::Matrix;
1612 ///
1613 /// let matrix1 = Matrix::init(20.0, (2,2));
1614 /// let matrix2 = Matrix::init(10.0, (2,2));
1615 ///
1616 /// assert_eq!(matrix1.dot(&matrix2).unwrap().get(0,0).unwrap(), 200.0);
1617 /// ```
1618 pub fn dot(&self, other: &Self) -> Result<Self, MatrixError> {
1619 self.mul(other)
1620 }
1621
1622 /// Bad handling of zero div
1623 ///
1624 /// # Examples
1625 ///
1626 /// ```
1627 /// use sukker::Matrix;
1628 ///
1629 /// let matrix1 = Matrix::init(20.0, (2,2));
1630 /// let matrix2 = Matrix::init(10.0, (2,2));
1631 ///
1632 /// assert_eq!(matrix1.div(&matrix2).unwrap().get(0,0).unwrap(), 2.0);
1633 /// ```
1634 pub fn div(&self, other: &Self) -> Result<Self, MatrixError> {
1635 if self.shape() != other.shape() {
1636 return Err(MatrixError::MatrixDimensionMismatchError.into());
1637 }
1638
1639 if other.any(|e| e == &T::zero()) {
1640 return Err(MatrixError::MatrixDivideByZeroError.into());
1641 }
1642
1643 let data = self
1644 .data
1645 .iter()
1646 .zip(other.data.iter())
1647 .map(|(&x, &y)| x / y)
1648 .collect_vec();
1649
1650 Ok(Self::new(data, self.shape()).unwrap())
1651 }
1652
1653 /// Negates every value in the matrix
1654 ///
1655 /// # Examples
1656 ///
1657 /// ```
1658 /// use sukker::{Matrix, LinAlgFloats};
1659 ///
1660 /// let matrix = Matrix::<f32>::ones((2,2));
1661 ///
1662 /// let negated = matrix.neg();
1663 ///
1664 /// assert_eq!(negated.all(|&e| e == -1.0), true);
1665 /// ```
1666 pub fn neg(&self) -> Self {
1667 let data: Vec<T> = self.data.par_iter().map(|&e| e.neg()).collect();
1668
1669 Self::new(data, self.shape()).unwrap()
1670 }
1671
1672 /// Adds a value to a matrix and returns a new matrix
1673 ///
1674 /// # Examples
1675 ///
1676 /// ```
1677 /// use sukker::Matrix;
1678 ///
1679 /// let matrix = Matrix::init(20.0, (2,2));
1680 /// let value: f32 = 2.0;
1681 /// assert_eq!(matrix.add_val(value).get(0,0).unwrap(), 22.0);
1682 /// ```
1683 pub fn add_val(&self, val: T) -> Self {
1684 let data: Vec<T> = self.data.par_iter().map(|&e| e + val).collect();
1685
1686 Self::new(data, self.shape()).unwrap()
1687 }
1688
1689 /// Substracts a value to a matrix and returns a new matrix
1690 ///
1691 /// # Examples
1692 ///
1693 /// ```
1694 /// use sukker::Matrix;
1695 ///
1696 /// let matrix = Matrix::init(20.0, (2,2));
1697 /// let value: f32 = 2.0;
1698 /// assert_eq!(matrix.sub_val(value).get(0,0).unwrap(), 18.0);
1699 /// ```
1700 pub fn sub_val(&self, val: T) -> Self {
1701 let data: Vec<T> = self.data.par_iter().map(|&e| e - val).collect();
1702
1703 Self::new(data, self.shape()).unwrap()
1704 }
1705
1706 /// Multiplies a value to a matrix and returns a new matrix
1707 ///
1708 /// # Examples
1709 ///
1710 /// ```
1711 /// use sukker::Matrix;
1712 ///
1713 /// let matrix = Matrix::init(20.0, (2,2));
1714 /// let value: f32 = 2.0;
1715 /// assert_eq!(matrix.mul_val(value).get(0,0).unwrap(), 40.0);
1716 /// ```
1717 pub fn mul_val(&self, val: T) -> Self {
1718 let data: Vec<T> = self.data.par_iter().map(|&e| e * val).collect();
1719
1720 Self::new(data, self.shape()).unwrap()
1721 }
1722
1723 /// Divides a value to a matrix and returns a new matrix
1724 ///
1725 /// # Examples
1726 ///
1727 /// ```
1728 /// use sukker::Matrix;
1729 ///
1730 /// let matrix = Matrix::init(20.0, (2,2));
1731 /// let value: f32 = 2.0;
1732 ///
1733 /// let result_mat = matrix.div_val(value);
1734 ///
1735 /// assert_eq!(result_mat.get(0,0).unwrap(), 10.0);
1736 /// ```
1737 pub fn div_val(&self, val: T) -> Self {
1738 let data: Vec<T> = self.data.par_iter().map(|&e| e / val).collect();
1739
1740 Self::new(data, self.shape()).unwrap()
1741 }
1742
1743 /// Pows each value in a matrix by val times
1744 ///
1745 /// # Examples
1746 ///
1747 /// ```
1748 /// use sukker::Matrix;
1749 ///
1750 /// let matrix = Matrix::init(2.0, (2,2));
1751 ///
1752 /// let result_mat = matrix.pow(2);
1753 ///
1754 /// assert_eq!(result_mat.get_vec(), vec![4.0, 4.0, 4.0, 4.0]);
1755 /// ```
1756 pub fn pow(&self, val: usize) -> Self {
1757 let data: Vec<T> = self.data.par_iter().map(|&e| pow(e, val)).collect();
1758
1759 Self::new(data, self.shape()).unwrap()
1760 }
1761
1762 /// Takes the absolute values of the matrix
1763 ///
1764 /// # Examples
1765 ///
1766 /// ```
1767 /// use sukker::Matrix;
1768 ///
1769 /// let matrix = Matrix::init(-20.0, (2,2));
1770 ///
1771 /// let res = matrix.abs();
1772 ///
1773 /// assert_eq!(res.all(|&e| e == 20.0), true);
1774 /// ```
1775 pub fn abs(&self) -> Self {
1776 let data: Vec<T> = self.data.par_iter().map(|&e| e.abs()).collect();
1777
1778 Self::new(data, self.shape()).unwrap()
1779 }
1780
1781 /// Multiply a matrix with itself n number of times.
1782 /// This is done by performing a matrix multiplication
1783 /// several time on self and the result of mat.exp(i-1).
1784 ///
1785 /// If matrix is not in form NxN, this function returns None
1786 ///
1787 /// Examples
1788 ///
1789 /// ```
1790 /// use sukker::Matrix;
1791 ///
1792 /// let a = Matrix::<i32>::init(2, (2,2));
1793 ///
1794 /// let res = a.exp(3).unwrap();
1795 ///
1796 /// assert_eq!(res.all(|&e| e == 32), true);
1797 /// ```
1798 pub fn exp(&self, n: usize) -> Option<Self> {
1799 if self.nrows != self.ncols {
1800 return None;
1801 }
1802
1803 let mut res = self.clone();
1804
1805 (0..n - 1).for_each(|_| res = res.matmul(self).unwrap());
1806
1807 Some(res)
1808 }
1809
1810 /// Adds a matrix in-place to a matrix
1811 ///
1812 /// # Examples
1813 ///
1814 /// ```
1815 /// use sukker::Matrix;
1816 ///
1817 /// let mut matrix1 = Matrix::init(20.0, (2,2));
1818 /// let matrix2 = Matrix::init(2.0, (2,2));
1819 ///
1820 /// matrix1.add_self(&matrix2);
1821 ///
1822 /// assert_eq!(matrix1.get(0,0).unwrap(), 22.0);
1823 /// ```
1824 pub fn add_self(&mut self, other: &Self) {
1825 self.data
1826 .par_iter_mut()
1827 .zip(&other.data)
1828 .for_each(|(a, b)| *a += *b);
1829 }
1830
1831 /// Subtracts a matrix in-place to a matrix
1832 ///
1833 /// # Examples
1834 ///
1835 /// ```
1836 /// use sukker::Matrix;
1837 ///
1838 /// let mut matrix1 = Matrix::init(20.0, (2,2));
1839 /// let matrix2 = Matrix::init(2.0, (2,2));
1840 ///
1841 /// matrix1.sub_self(&matrix2);
1842 ///
1843 /// assert_eq!(matrix1.get(0,0).unwrap(), 18.0);
1844 /// ```
1845 pub fn sub_self(&mut self, other: &Self) {
1846 self.data
1847 .par_iter_mut()
1848 .zip(&other.data)
1849 .for_each(|(a, b)| *a -= *b);
1850 }
1851
1852 /// Multiplies a matrix in-place to a matrix
1853 ///
1854 /// # Examples
1855 ///
1856 /// ```
1857 /// use sukker::Matrix;
1858 ///
1859 /// let mut matrix1 = Matrix::init(20.0, (2,2));
1860 /// let matrix2 = Matrix::init(2.0, (2,2));
1861 ///
1862 /// matrix1.mul_self(&matrix2);
1863 ///
1864 /// assert_eq!(matrix1.get(0,0).unwrap(), 40.0);
1865 /// ```
1866 pub fn mul_self(&mut self, other: &Self) {
1867 self.data
1868 .par_iter_mut()
1869 .zip(&other.data)
1870 .for_each(|(a, b)| *a *= *b);
1871 }
1872
1873 /// Divides a matrix in-place to a matrix
1874 ///
1875 /// # Examples
1876 ///
1877 /// ```
1878 /// use sukker::Matrix;
1879 ///
1880 /// let mut matrix1 = Matrix::init(20.0, (2,2));
1881 /// let matrix2 = Matrix::init(2.0, (2,2));
1882 ///
1883 /// matrix1.div_self(&matrix2);
1884 ///
1885 /// assert_eq!(matrix1.get(0,0).unwrap(), 10.0);
1886 /// ```
1887 pub fn div_self(&mut self, other: &Self) {
1888 self.data
1889 .par_iter_mut()
1890 .zip(&other.data)
1891 .for_each(|(a, b)| *a /= *b);
1892 }
1893
1894 /// Abs matrix in-place to a matrix
1895 ///
1896 /// # Examples
1897 ///
1898 /// ```
1899 /// use sukker::Matrix;
1900 ///
1901 /// let mut matrix = Matrix::init(20.0, (2,2));
1902 ///
1903 /// matrix.abs_self()
1904 ///
1905 /// // assert_eq!(matrix1.get(0,0).unwrap(), 22.0);
1906 /// ```
1907 pub fn abs_self(&mut self) {
1908 self.data.par_iter_mut().for_each(|e| *e = abs(*e))
1909 }
1910
1911 /// Adds a value in-place to a matrix
1912 ///
1913 /// # Examples
1914 ///
1915 /// ```
1916 /// use sukker::Matrix;
1917 ///
1918 /// let mut matrix = Matrix::init(20.0, (2,2));
1919 /// let value: f32 = 2.0;
1920 ///
1921 /// matrix.add_val_self(value);
1922 ///
1923 /// assert_eq!(matrix.get(0,0).unwrap(), 22.0);
1924 /// ```
1925 pub fn add_val_self(&mut self, val: T) {
1926 self.data.par_iter_mut().for_each(|e| *e += val);
1927 }
1928
1929 /// Subtracts a value in-place to a matrix
1930 ///
1931 /// # Examples
1932 ///
1933 /// ```
1934 /// use sukker::Matrix;
1935 ///
1936 /// let mut matrix = Matrix::init(20.0, (2,2));
1937 /// let value: f32 = 2.0;
1938 ///
1939 /// matrix.sub_val_self(value);
1940 ///
1941 /// assert_eq!(matrix.get(0,0).unwrap(), 18.0);
1942 /// ```
1943 pub fn sub_val_self(&mut self, val: T) {
1944 self.data.par_iter_mut().for_each(|e| *e -= val);
1945 }
1946
1947 /// Mults a value in-place to a matrix
1948 ///
1949 /// # Examples
1950 ///
1951 /// ```
1952 /// use sukker::Matrix;
1953 ///
1954 /// let mut matrix = Matrix::init(20.0, (2,2));
1955 /// let value: f32 = 2.0;
1956 ///
1957 /// matrix.mul_val_self(value);
1958 ///
1959 /// assert_eq!(matrix.get(0,0).unwrap(), 40.0);
1960 /// ```
1961 pub fn mul_val_self(&mut self, val: T) {
1962 self.data.par_iter_mut().for_each(|e| *e *= val);
1963 }
1964
1965 /// Divs a value in-place to a matrix
1966 ///
1967 /// # Examples
1968 ///
1969 /// ```
1970 /// use sukker::Matrix;
1971 ///
1972 /// let mut matrix = Matrix::init(20.0, (2,2));
1973 /// let value: f32 = 2.0;
1974 ///
1975 /// matrix.div_val_self(value);
1976 ///
1977 /// assert_eq!(matrix.get(0,0).unwrap(), 10.0);
1978 /// ```
1979 pub fn div_val_self(&mut self, val: T) {
1980 self.data.par_iter_mut().for_each(|e| *e /= val);
1981 }
1982
1983 /// Transposed matrix multiplications
1984 ///
1985 /// # Examples
1986 ///
1987 /// ```
1988 /// use sukker::Matrix;
1989 ///
1990 /// let mut matrix1 = Matrix::init(2.0, (2,4));
1991 /// let matrix2 = Matrix::init(2.0, (4,2));
1992 ///
1993 /// let result = matrix1.matmul(&matrix2).unwrap();
1994 ///
1995 /// assert_eq!(result.get(0,0).unwrap(), 16.0);
1996 /// assert_eq!(result.shape(), (2,2));
1997 /// ```
1998 pub fn matmul(&self, other: &Self) -> Result<Self, MatrixError> {
1999 // assert M N x N P
2000 if self.ncols != other.nrows {
2001 return Err(MatrixError::MatrixDimensionMismatchError.into());
2002 }
2003
2004 Ok(self.matmul_helper(other))
2005 }
2006
2007 /// Shorthand method for matmul
2008 pub fn mm(&self, other: &Self) -> Result<Self, MatrixError> {
2009 self.matmul(other)
2010 }
2011
2012 /// Get's the determinat of a N x N matrix
2013 ///
2014 /// Examples
2015 ///
2016 /// ```
2017 /// use sukker::Matrix;
2018 ///
2019 /// let mat: Matrix<i32> = Matrix::new(vec![1,3,5,9,1,3,1,7,4,3,9,7,5,2,0,9], (4,4)).unwrap();
2020 ///
2021 ///
2022 /// let res = mat.determinant().unwrap();
2023 ///
2024 /// assert_eq!(res, -376);
2025 /// ```
2026 pub fn determinant(&self) -> Option<T> {
2027 if self.nrows != self.ncols {
2028 return None;
2029 }
2030
2031 Some(self.determinant_helper())
2032 }
2033
2034 /// Shorthand call for `determinant`
2035 pub fn det(&self) -> Option<T> {
2036 self.determinant()
2037 }
2038
2039 /// Finds the inverse of a matrix if possible
2040 ///
2041 /// Definition: AA^-1 = A^-1A = I
2042 ///
2043 /// Examples
2044 ///
2045 /// ```
2046 /// use sukker::Matrix;
2047 ///
2048 /// let matrix = Matrix::new(vec![4,7,2,6], (2,2)).unwrap();
2049 ///
2050 /// // let inverse = matrix.inverse();
2051 ///
2052 /// ```
2053 pub fn inverse(&self) -> Option<Self> {
2054 if self.shape() != (2, 2) {
2055 eprintln!("Function not implemented for inverse on larger matrices yet!");
2056 return None;
2057 }
2058 if self.nrows != self.ncols {
2059 eprintln!("Oops");
2060 return None;
2061 }
2062
2063 if self.determinant().unwrap() == T::zero() {
2064 return None;
2065 }
2066
2067 let a = self.at(0, 0);
2068 let b = self.at(0, 1);
2069 let c = self.at(1, 0);
2070 let d = self.at(1, 1);
2071
2072 let mut mat = Self::new(vec![d, -b, -c, a], self.shape()).unwrap();
2073
2074 mat.mul_val_self(T::one() / (a * d - b * c));
2075
2076 return Some(mat);
2077
2078 // let mut inverse = Self::zeros_like(self);
2079 //
2080 // let identity_mat = Self::eye_like(self);
2081 //
2082 // Some(inverse)
2083 }
2084
2085 /// Transpose a matrix in-place
2086 ///
2087 /// # Examples
2088 ///
2089 /// ```
2090 /// use sukker::Matrix;
2091 ///
2092 /// let mut matrix = Matrix::init(2.0, (2,100));
2093 /// matrix.transpose();
2094 ///
2095 /// assert_eq!(matrix.shape(), (100,2));
2096 /// ```
2097 pub fn transpose(&mut self) {
2098 for i in 0..self.nrows {
2099 for j in (i + 1)..self.ncols {
2100 let lhs = at!(i, j, self.ncols);
2101 let rhs = at!(j, i, self.nrows);
2102 self.data.swap(lhs, rhs);
2103 }
2104 }
2105
2106 swap(&mut self.nrows, &mut self.ncols);
2107 }
2108
2109 /// Shorthand call for transpose
2110 ///
2111 /// # Examples
2112 ///
2113 /// ```
2114 /// use sukker::Matrix;
2115 ///
2116 /// let mut matrix = Matrix::init(2.0, (2,100));
2117 /// matrix.t();
2118 ///
2119 /// assert_eq!(matrix.shape(), (100,2));
2120 /// ```
2121 pub fn t(&mut self) {
2122 self.transpose()
2123 }
2124
2125 /// Transpose a matrix and return a copy
2126 ///
2127 /// # Examples
2128 ///
2129 /// ```
2130 /// use sukker::Matrix;
2131 ///
2132 /// let matrix = Matrix::init(2.0, (2,100));
2133 /// let result = matrix.transpose_copy();
2134 ///
2135 /// assert_eq!(result.shape(), (100,2));
2136 /// ```
2137 pub fn transpose_copy(&self) -> Self {
2138 let mut res = self.clone();
2139 res.transpose();
2140 res
2141 }
2142}
2143
2144/// Implementations for predicates
2145impl<'a, T> Matrix<'a, T>
2146where
2147 T: MatrixElement,
2148 <T as FromStr>::Err: Error + 'static,
2149 Vec<T>: IntoParallelIterator,
2150 Vec<&'a T>: IntoParallelRefIterator<'a>,
2151{
2152 /// Counts all occurances where predicate holds
2153 ///
2154 /// # Examples
2155 ///
2156 /// ```
2157 /// use sukker::Matrix;
2158 ///
2159 /// let matrix = Matrix::init(2.0f32, (2,4));
2160 ///
2161 /// assert_eq!(matrix.count_where(|&e| e == 2.0), 8);
2162 /// ```
2163 pub fn count_where<F>(&'a self, pred: F) -> usize
2164 where
2165 F: Fn(&T) -> bool + Sync,
2166 {
2167 self.data.par_iter().filter(|&e| pred(e)).count()
2168 }
2169
2170 /// Sums all occurances where predicate holds
2171 ///
2172 /// # Examples
2173 ///
2174 /// ```
2175 /// use sukker::Matrix;
2176 ///
2177 /// let matrix = Matrix::init(2.0, (2,4));
2178 ///
2179 /// assert_eq!(matrix.sum_where(|&e| e == 2.0), 16.0);
2180 /// ```
2181 pub fn sum_where<F>(&self, pred: F) -> T
2182 where
2183 F: Fn(&T) -> bool + Sync,
2184 {
2185 self.data
2186 .par_iter()
2187 .filter(|&e| pred(e))
2188 .copied()
2189 .sum::<T>()
2190 }
2191
2192 /// Setsall elements where predicate holds true.
2193 /// The new value is to be set inside the predicate as well
2194 ///
2195 /// # Examples
2196 ///
2197 /// ```
2198 /// use sukker::Matrix;
2199 ///
2200 /// let mut matrix = Matrix::init(2.0, (2,4));
2201 ///
2202 /// assert_eq!(matrix.get(0,0).unwrap(), 2.0);
2203 ///
2204 /// matrix.set_where(|e| {
2205 /// if *e == 2.0 {
2206 /// *e = 2.3;
2207 /// }
2208 /// });
2209 ///
2210 /// assert_eq!(matrix.get(0,0).unwrap(), 2.3);
2211 /// ```
2212 pub fn set_where<P>(&mut self, mut pred: P)
2213 where
2214 P: FnMut(&mut T) + Sync + Send,
2215 {
2216 self.data.iter_mut().for_each(|e| pred(e));
2217 }
2218
2219 /// Return whether or not a predicate holds at least once
2220 ///
2221 /// # Examples
2222 ///
2223 /// ```
2224 /// use sukker::Matrix;
2225 ///
2226 /// let matrix = Matrix::init(2.0, (2,4));
2227 ///
2228 /// assert_eq!(matrix.any(|&e| e == 2.0), true);
2229 /// ```
2230 pub fn any<F>(&self, pred: F) -> bool
2231 where
2232 F: Fn(&T) -> bool + Sync + Send,
2233 {
2234 self.data.par_iter().any(pred)
2235 }
2236
2237 /// Returns whether or not predicate holds for all values
2238 ///
2239 /// # Examples
2240 ///
2241 /// ```
2242 /// use sukker::Matrix;
2243 ///
2244 /// let matrix = Matrix::randomize_range(1.0, 4.0, (2,4));
2245 ///
2246 /// assert_eq!(matrix.all(|&e| e >= 1.0), true);
2247 /// ```
2248 pub fn all<F>(&self, pred: F) -> bool
2249 where
2250 F: Fn(&T) -> bool + Sync + Send,
2251 {
2252 self.data.par_iter().all(pred)
2253 }
2254
2255 /// Finds first index where predicates holds if possible
2256 ///
2257 /// # Examples
2258 ///
2259 /// ```
2260 /// use sukker::Matrix;
2261 ///
2262 /// let matrix = Matrix::init(2f32, (2,4));
2263 ///
2264 /// assert_eq!(matrix.find(|&e| e >= 1f32), Some((0,0)));
2265 /// ```
2266 pub fn find<F>(&self, pred: F) -> Option<Shape>
2267 where
2268 F: Fn(&T) -> bool + Sync,
2269 {
2270 if let Some((idx, _)) = self.data.iter().find_position(|&e| pred(e)) {
2271 return Some(self.one_to_2d_idx(idx));
2272 }
2273
2274 None
2275 }
2276
2277 /// Finds all indeces where predicates holds if possible
2278 ///
2279 /// # Examples
2280 ///
2281 /// ```
2282 /// use sukker::Matrix;
2283 ///
2284 /// let matrix = Matrix::init(2.0, (2,4));
2285 ///
2286 /// assert_eq!(matrix.find_all(|&e| e >= 3.0), None);
2287 /// ```
2288 pub fn find_all<F>(&self, pred: F) -> Option<Vec<Shape>>
2289 where
2290 F: Fn(&T) -> bool + Sync,
2291 {
2292 let data: Vec<Shape> = self
2293 .data
2294 .par_iter()
2295 .enumerate()
2296 .filter_map(|(idx, elem)| {
2297 if pred(elem) {
2298 Some(self.one_to_2d_idx(idx))
2299 } else {
2300 None
2301 }
2302 })
2303 .collect();
2304
2305 if data.is_empty() {
2306 None
2307 } else {
2308 Some(data)
2309 }
2310 }
2311}