easy_ml/matrices/mod.rs
1/*!
2 * Generic matrix type.
3 *
4 * Matrices are generic over some type `T`. If `T` is [Numeric](super::numeric) then
5 * the matrix can be used in a mathematical way.
6 */
7
8#[cfg(feature = "serde")]
9use serde::Serialize;
10
11mod errors;
12pub mod iterators;
13pub mod operations;
14pub mod slices;
15pub mod views;
16
17pub use errors::ScalarConversionError;
18
19use crate::linear_algebra;
20use crate::matrices::iterators::*;
21use crate::matrices::slices::Slice2D;
22use crate::matrices::views::{
23 IndexRange, MatrixMask, MatrixPart, MatrixQuadrants, MatrixRange, MatrixReverse, MatrixView,
24 Reverse,
25};
26use crate::numeric::extra::{Real, RealRef};
27use crate::numeric::{Numeric, NumericRef};
28
/// The maximum row and column lengths are usize, due to the internal storage being backed by Vec
pub type Row = usize;
/// The maximum row and column lengths are usize, due to the internal storage being backed by Vec
pub type Column = usize;

/**
 * A general purpose matrix of some type. `T` is not required to implement any traits, though
 * without any the matrix will be of little use. If `T` implements [`Clone`] most storage and
 * accessor methods are available, and if `T` implements [`Numeric`](super::numeric) the matrix
 * can also be used in a mathematical way.
 *
 * When doing numeric operations with Matrices be careful not to consume a matrix by
 * accidentally using it by value. Every operation is also defined on references to
 * matrices, so prefer `&x * &y` style notation for matrices you intend to keep using.
 * Convenience operations are also defined between a matrix and a scalar.
 *
 * # Matrix size invariants
 *
 * Matrices must always be at least 1x1. You cannot construct a matrix with no rows or
 * no columns, and any function that resizes matrices will error if you try to use it
 * in a way that would construct a 0x1, 1x0, or 0x0 matrix. The maximum size of a matrix
 * depends on the platform's `std::isize::MAX` value. Matrices with dimensions NxM
 * such that N * M < `std::isize::MAX` should not cause any errors in this library, but
 * attempting to expand their size further may cause panics and/or errors. At the time of
 * writing it is no longer possible to construct or use matrices where the product of their
 * number of rows and columns exceeds `std::isize::MAX`, but some constructor methods may be
 * used to attempt this. Concerned readers should note that on a 64 bit computer this maximum
 * value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
 *
 * # Matrix layout and iterator performance
 *
 * [See iterators submodule for Matrix layout and iterator performance](iterators#matrix-layout-and-iterator-performance)
 *
 * # Matrix operations
 *
 * [See operations submodule](operations)
 */
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Matrix<T> {
    // Every element in row major order; constructors keep this `rows * columns` long.
    data: Vec<T>,
    // Number of rows, at least 1 for any constructed matrix.
    rows: Row,
    // Number of columns (the length of each row), at least 1 for any constructed matrix.
    columns: Column,
}
75
76/**
77 * Methods for matrices of any type, including non numerical types such as bool.
78 */
79impl<T> Matrix<T> {
80 /**
81 * Creates a 1x1 matrix from some scalar
82 */
83 pub fn from_scalar(value: T) -> Matrix<T> {
84 Matrix {
85 data: vec![value],
86 rows: 1,
87 columns: 1,
88 }
89 }
90
91 /**
92 * Creates a row vector (1xN) from a list
93 *
94 * # Panics
95 *
96 * Panics if no values are provided. Note: this method erroneously did not validate its inputs
97 * in Easy ML versions up to and including 1.7.0
98 */
99 #[track_caller]
100 pub fn row(values: Vec<T>) -> Matrix<T> {
101 assert!(!values.is_empty(), "No values provided");
102 Matrix {
103 columns: values.len(),
104 data: values,
105 rows: 1,
106 }
107 }
108
109 /**
110 * Creates a column vector (Nx1) from a list
111 *
112 * # Panics
113 *
114 * Panics if no values are provided. Note: this method erroneously did not validate its inputs
115 * in Easy ML versions up to and including 1.7.0
116 */
117 #[track_caller]
118 pub fn column(values: Vec<T>) -> Matrix<T> {
119 assert!(!values.is_empty(), "No values provided");
120 Matrix {
121 rows: values.len(),
122 data: values,
123 columns: 1,
124 }
125 }
126
127 /**
128 * Creates a matrix from a nested array of values, each inner vector
129 * being a row, and hence the outer vector containing all rows in sequence, the
130 * same way as when writing matrices in mathematics.
131 *
132 * Example of a 2 x 3 matrix in both notations:
133 * ```ignore
134 * [
135 * 1, 2, 4
136 * 8, 9, 3
137 * ]
138 * ```
139 * ```
140 * use easy_ml::matrices::Matrix;
141 * Matrix::from(vec![
142 * vec![ 1, 2, 4 ],
143 * vec![ 8, 9, 3 ]
144 * ]);
145 * ```
146 *
147 * # Panics
148 *
149 * Panics if the input is jagged or rows or column length is 0.
150 */
151 #[track_caller]
152 pub fn from(mut values: Vec<Vec<T>>) -> Matrix<T> {
153 assert!(!values.is_empty(), "No rows defined");
154 // check length of first row is > 1
155 assert!(!values[0].is_empty(), "No column defined");
156 // check length of each row is the same
157 assert!(
158 values.iter().map(|x| x.len()).all(|x| x == values[0].len()),
159 "Inconsistent size"
160 );
161 // flatten the data into a row major layout
162 let rows = values.len();
163 let columns = values[0].len();
164 let mut data = Vec::with_capacity(rows * columns);
165 let mut value_stream = values.drain(..);
166 for _ in 0..rows {
167 let mut value_row_stream = value_stream.next().unwrap();
168 let mut row_of_values = value_row_stream.drain(..);
169 for _ in 0..columns {
170 data.push(row_of_values.next().unwrap());
171 }
172 }
173 Matrix {
174 data,
175 rows,
176 columns,
177 }
178 }
179
180 /**
181 * Creates a matrix with the specified size from a row major vec of data.
182 * The length of the vec must match the size of the matrix or the constructor
183 * will panic.
184 *
185 * Example of a 2 x 3 matrix in both notations:
186 * ```ignore
187 * [
188 * 1, 2, 4
189 * 8, 9, 3
190 * ]
191 * ```
192 * ```
193 * use easy_ml::matrices::Matrix;
194 * Matrix::from_flat_row_major((2, 3), vec![
195 * 1, 2, 4,
196 * 8, 9, 3
197 * ]);
198 * ```
199 *
200 * This method is more efficient than [`Matrix::from`](Matrix::from())
201 * but requires specifying the size explicitly and manually keeping track of where rows
202 * start and stop.
203 *
204 * # Panics
205 *
206 * Panics if the length of the vec does not match the size of the matrix, or no values are
207 * provided. Note: this method erroneously did not validate its inputs were not empty in
208 * Easy ML versions up to and including 1.7.0
209 */
210 #[track_caller]
211 pub fn from_flat_row_major(size: (Row, Column), values: Vec<T>) -> Matrix<T> {
212 assert!(
213 size.0 * size.1 == values.len(),
214 "Inconsistent size, attempted to construct a {}x{} matrix but provided with {} elements.",
215 size.0,
216 size.1,
217 values.len()
218 );
219 assert!(!values.is_empty(), "No values provided");
220 Matrix {
221 data: values,
222 rows: size.0,
223 columns: size.1,
224 }
225 }
226
227 /**
228 * Creates a matrix with the specified size initalised from a function.
229 *
230 * ```
231 * use easy_ml::matrices::Matrix;
232 * let matrix = Matrix::from_fn((4, 4), |(r, c)| r * c);
233 * assert_eq!(
234 * matrix,
235 * Matrix::from(vec![
236 * vec![ 0, 0, 0, 0 ],
237 * vec![ 0, 1, 2, 3 ],
238 * vec![ 0, 2, 4, 6 ],
239 * vec![ 0, 3, 6, 9 ],
240 * ])
241 * );
242 * ```
243 *
244 * # Panics
245 *
246 * Panics if the size has 0 rows or columns.
247 */
248 #[track_caller]
249 pub fn from_fn<F>(size: (Row, Column), mut producer: F) -> Matrix<T>
250 where
251 F: FnMut((Row, Column)) -> T,
252 {
253 use crate::tensors::indexing::ShapeIterator;
254 let length = size.0 * size.1;
255 let mut data = Vec::with_capacity(length);
256 let iterator = ShapeIterator::from([("row", size.0), ("column", size.1)]);
257 for [r, c] in iterator {
258 data.push(producer((r, c)));
259 }
260 Matrix::from_flat_row_major(size, data)
261 }
262
263 #[deprecated(
264 since = "1.1.0",
265 note = "Incorrect use of terminology, a unit matrix is another term for an identity matrix, please use `from_scalar` instead"
266 )]
267 pub fn unit(value: T) -> Matrix<T> {
268 Matrix::from_scalar(value)
269 }
270
271 /**
272 * Returns the dimensionality of this matrix in Row, Column format
273 */
274 pub fn size(&self) -> (Row, Column) {
275 (self.rows, self.columns)
276 }
277
278 /**
279 * Gets the number of rows in this matrix.
280 */
281 pub fn rows(&self) -> Row {
282 self.rows
283 }
284
285 /**
286 * Gets the number of columns in this matrix.
287 */
288 pub fn columns(&self) -> Column {
289 self.columns
290 }
291
292 /**
293 * Matrix data is stored as row major, so each row is stored as
294 * adjacent items going through the different columns. Therefore,
295 * to index this flattened representation we jump down in row sized
296 * blocks to reach the correct row, and then jump further equal to
297 * the column. The confusing thing is that the number of columns
298 * this matrix has is the length of each of the rows in this matrix,
299 * and vice versa.
300 */
301 fn get_index(&self, row: Row, column: Column) -> usize {
302 column + (row * self.columns())
303 }
304
305 /**
306 * The reverse of [get_index], converts from the flattened storage
307 * in memory into the row and column to index at this position.
308 *
309 * Matrix data is stored as row major, so each multiple of the number
310 * of columns starts a new row, and each index modulo the columns
311 * gives the column.
312 */
313 #[allow(dead_code)]
314 fn get_row_column(&self, index: usize) -> (Row, Column) {
315 (index / self.columns(), index % self.columns())
316 }
317
318 /**
319 * Gets a reference to the value at this row and column. Rows and Columns are 0 indexed.
320 *
321 * # Panics
322 *
323 * Panics if the index is out of range.
324 */
325 #[track_caller]
326 pub fn get_reference(&self, row: Row, column: Column) -> &T {
327 assert!(row < self.rows(), "Row out of index");
328 assert!(column < self.columns(), "Column out of index");
329 &self.data[self.get_index(row, column)]
330 }
331
332 /**
333 * Gets a mutable reference to the value at this row and column.
334 * Rows and Columns are 0 indexed.
335 *
336 * # Panics
337 *
338 * Panics if the index is out of range.
339 */
340 #[track_caller]
341 pub fn get_reference_mut(&mut self, row: Row, column: Column) -> &mut T {
342 assert!(row < self.rows(), "Row out of index");
343 assert!(column < self.columns(), "Column out of index");
344 let index = self.get_index(row, column);
345 // borrow for get_index ends
346 &mut self.data[index]
347 }
348
349 /**
350 * Not public API because don't want to name clash with the method on MatrixRef
351 * that calls this.
352 */
353 pub(crate) fn _try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
354 if row < self.rows() && column < self.columns() {
355 Some(&self.data[self.get_index(row, column)])
356 } else {
357 None
358 }
359 }
360
361 /**
362 * Not public API because don't want to name clash with the method on MatrixRef
363 * that calls this.
364 */
365 pub(crate) unsafe fn _get_reference_unchecked(&self, row: Row, column: Column) -> &T {
366 unsafe { self.data.get_unchecked(self.get_index(row, column)) }
367 }
368
369 /**
370 * Sets a new value to this row and column. Rows and Columns are 0 indexed.
371 *
372 * # Panics
373 *
374 * Panics if the index is out of range.
375 */
376 #[track_caller]
377 pub fn set(&mut self, row: Row, column: Column, value: T) {
378 assert!(row < self.rows(), "Row out of index");
379 assert!(column < self.columns(), "Column out of index");
380 let index = self.get_index(row, column);
381 // borrow for get_index ends
382 self.data[index] = value;
383 }
384
385 /**
386 * Not public API because don't want to name clash with the method on MatrixMut
387 * that calls this.
388 */
389 pub(crate) fn _try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
390 if row < self.rows() && column < self.columns() {
391 let index = self.get_index(row, column);
392 // borrow for get_index ends
393 Some(&mut self.data[index])
394 } else {
395 None
396 }
397 }
398
399 /**
400 * Not public API because don't want to name clash with the method on MatrixMut
401 * that calls this.
402 */
403 pub(crate) unsafe fn _get_reference_unchecked_mut(
404 &mut self,
405 row: Row,
406 column: Column,
407 ) -> &mut T {
408 unsafe {
409 let index = self.get_index(row, column);
410 // borrow for get_index ends
411 self.data.get_unchecked_mut(index)
412 }
413 }
414
415 /**
416 * Removes a row from this Matrix, shifting all other rows to the left.
417 * Rows are 0 indexed.
418 *
419 * # Panics
420 *
421 * This will panic if the row does not exist or the matrix only has one row.
422 */
423 #[track_caller]
424 pub fn remove_row(&mut self, row: Row) {
425 assert!(self.rows() > 1);
426 let mut r = 0;
427 let mut c = 0;
428 // drop the values at the specified row
429 let columns = self.columns();
430 self.data.retain(|_| {
431 let keep = r != row;
432 if c < (columns - 1) {
433 c += 1;
434 } else {
435 r += 1;
436 c = 0;
437 }
438 keep
439 });
440 self.rows -= 1;
441 }
442
443 /**
444 * Removes a column from this Matrix, shifting all other columns to the left.
445 * Columns are 0 indexed.
446 *
447 * # Panics
448 *
449 * This will panic if the column does not exist or the matrix only has one column.
450 */
451 #[track_caller]
452 pub fn remove_column(&mut self, column: Column) {
453 assert!(self.columns() > 1);
454 let mut r = 0;
455 let mut c = 0;
456 // drop the values at the specified column
457 let columns = self.columns();
458 self.data.retain(|_| {
459 let keep = c != column;
460 if c < (columns - 1) {
461 c += 1;
462 } else {
463 r += 1;
464 c = 0;
465 }
466 keep
467 });
468 self.columns -= 1;
469 }
470
471 /**
472 * Returns an iterator over references to a column vector in this matrix.
473 * Columns are 0 indexed.
474 *
475 * # Panics
476 *
477 * Panics if the column does not exist in this matrix.
478 */
479 #[track_caller]
480 pub fn column_reference_iter(&self, column: Column) -> ColumnReferenceIterator<T> {
481 ColumnReferenceIterator::new(self, column)
482 }
483
484 /**
485 * Returns an iterator over references to a row vector in this matrix.
486 * Rows are 0 indexed.
487 *
488 * # Panics
489 *
490 * Panics if the row does not exist in this matrix.
491 */
492 #[track_caller]
493 pub fn row_reference_iter(&self, row: Row) -> RowReferenceIterator<T> {
494 RowReferenceIterator::new(self, row)
495 }
496
497 /**
498 * Returns an iterator over mutable references to a column vector in this matrix.
499 * Columns are 0 indexed.
500 *
501 * # Panics
502 *
503 * Panics if the column does not exist in this matrix.
504 */
505 #[track_caller]
506 pub fn column_reference_mut_iter(&mut self, column: Column) -> ColumnReferenceMutIterator<T> {
507 ColumnReferenceMutIterator::new(self, column)
508 }
509
510 /**
511 * Returns an iterator over mutable references to a row vector in this matrix.
512 * Rows are 0 indexed.
513 *
514 * # Panics
515 *
516 * Panics if the row does not exist in this matrix.
517 */
518 #[track_caller]
519 pub fn row_reference_mut_iter(&mut self, row: Row) -> RowReferenceMutIterator<T> {
520 RowReferenceMutIterator::new(self, row)
521 }
522
523 /**
524 * Returns a column major iterator over references to all values in this matrix,
525 * proceeding through each column in order.
526 */
527 pub fn column_major_reference_iter(&self) -> ColumnMajorReferenceIterator<T> {
528 ColumnMajorReferenceIterator::new(self)
529 }
530
531 /**
532 * Returns a row major iterator over references to all values in this matrix,
533 * proceeding through each row in order.
534 */
535 pub fn row_major_reference_iter(&self) -> RowMajorReferenceIterator<T> {
536 RowMajorReferenceIterator::new(self)
537 }
538
539 // Non public row major reference iterator since we don't want to expose our implementation
540 // details to public API since then we could never change them.
541 pub(crate) fn direct_row_major_reference_iter(&self) -> std::slice::Iter<T> {
542 self.data.iter()
543 }
544
545 // Non public row major reference iterator since we don't want to expose our implementation
546 // details to public API since then we could never change them.
547 pub(crate) fn direct_row_major_reference_iter_mut(&mut self) -> std::slice::IterMut<T> {
548 self.data.iter_mut()
549 }
550
551 /**
552 * Returns a column major iterator over mutable references to all values in this matrix,
553 * proceeding through each column in order.
554 */
555 pub fn column_major_reference_mut_iter(&mut self) -> ColumnMajorReferenceMutIterator<T> {
556 ColumnMajorReferenceMutIterator::new(self)
557 }
558
559 /**
560 * Returns a row major iterator over mutable references to all values in this matrix,
561 * proceeding through each row in order.
562 */
563 pub fn row_major_reference_mut_iter(&mut self) -> RowMajorReferenceMutIterator<T> {
564 RowMajorReferenceMutIterator::new(self)
565 }
566
567 /**
568 * Creates a column major iterator over all values in this matrix,
569 * proceeding through each column in order.
570 */
571 pub fn column_major_owned_iter(self) -> ColumnMajorOwnedIterator<T>
572 where
573 T: Default,
574 {
575 ColumnMajorOwnedIterator::new(self)
576 }
577
578 /**
579 * Creates a row major iterator over all values in this matrix,
580 * proceeding through each row in order.
581 */
582 pub fn row_major_owned_iter(self) -> RowMajorOwnedIterator<T>
583 where
584 T: Default,
585 {
586 RowMajorOwnedIterator::new(self)
587 }
588
589 /**
590 * Returns an iterator over references to the main diagonal in this matrix.
591 */
592 pub fn diagonal_reference_iter(&self) -> DiagonalReferenceIterator<T> {
593 DiagonalReferenceIterator::new(self)
594 }
595
596 /**
597 * Returns an iterator over mutable references to the main diagonal in this matrix.
598 */
599 pub fn diagonal_reference_mut_iter(&mut self) -> DiagonalReferenceMutIterator<T> {
600 DiagonalReferenceMutIterator::new(self)
601 }
602
603 /**
604 * Shrinks this matrix down from its current MxN size down to
605 * some new size OxP where O and P are determined by the kind of
606 * slice given and 1 <= O <= M and 1 <= P <= N.
607 *
608 * Only rows and columns specified by the slice will be retained, so for
609 * instance if the Slice is constructed by
610 * `Slice2D::new().rows(Slice::Range(0..2)).columns(Slice::Range(0..3))` then the
611 * modified matrix will be no bigger than 2x3 and contain up to the first two
612 * rows and first three columns that it previously had.
613 *
614 * See [Slice](slices::Slice) for constructing slices.
615 *
616 * # Panics
617 *
618 * This function will panic if the slice would delete all rows or all columns
619 * from this matrix, ie the resulting matrix must be at least 1x1.
620 */
621 #[track_caller]
622 pub fn retain_mut(&mut self, slice: Slice2D) {
623 let mut r = 0;
624 let mut c = 0;
625 // drop the values rejected by the slice
626 let columns = self.columns();
627 self.data.retain(|_| {
628 let keep = slice.accepts(r, c);
629 if c < (columns - 1) {
630 c += 1;
631 } else {
632 r += 1;
633 c = 0;
634 }
635 keep
636 });
637 // work out the resulting size of this matrix by using the non
638 // public fields of the Slice2D to handle each row and column
639 // seperately.
640 let remaining_rows = {
641 let mut accepted = 0;
642 for i in 0..self.rows() {
643 if slice.rows.accepts(i) {
644 accepted += 1;
645 }
646 }
647 accepted
648 };
649 let remaining_columns = {
650 let mut accepted = 0;
651 for i in 0..self.columns() {
652 if slice.columns.accepts(i) {
653 accepted += 1;
654 }
655 }
656 accepted
657 };
658 assert!(
659 remaining_rows > 0,
660 "Provided slice must leave at least 1 row in the retained matrix"
661 );
662 assert!(
663 remaining_columns > 0,
664 "Provided slice must leave at least 1 column in the retained matrix"
665 );
666 assert!(
667 !self.data.is_empty(),
668 "Provided slice must leave at least 1 row and 1 column in the retained matrix"
669 );
670 self.rows = remaining_rows;
671 self.columns = remaining_columns
672 // By construction jagged slices should be impossible, if this
673 // invariant later changes by accident it would be possible to break the
674 // rectangle shape invariant on a matrix object
675 // As Slice2D should prevent the construction of jagged slices no
676 // check is here to detect if all rows are still the same length
677 }
678
679 /**
680 * Consumes a 1x1 matrix and converts it into a scalar without copying the data.
681 *
682 * # Example
683 *
684 * ```
685 * use easy_ml::matrices::Matrix;
686 * # fn main() -> Result<(), Box<dyn std::error::Error>> {
687 * let x = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
688 * let sum_of_squares: f64 = (x.transpose() * x).try_into_scalar()?;
689 * # Ok(())
690 * # }
691 * ```
692 */
693 pub fn try_into_scalar(self) -> Result<T, ScalarConversionError> {
694 if self.size() == (1, 1) {
695 Ok(self.data.into_iter().next().unwrap())
696 } else {
697 Err(ScalarConversionError {})
698 }
699 }
700
701 /**
702 * Partition a matrix into an arbitary number of non overlapping parts.
703 *
704 * **This function is much like a hammer you should be careful to not overuse. If you don't need
705 * to mutate the parts of the matrix data individually it will be much easier and less error
706 * prone to create immutable views into the matrix using [MatrixRange] instead.**
707 *
708 * Parts are returned in row major order, forming a grid of slices into the Matrix data that
709 * can be mutated independently.
710 *
711 * # Panics
712 *
713 * Panics if any row or column index is greater than the number of rows or columns in the
714 * matrix. Each list of row partitions and column partitions must also be in ascending order.
715 *
716 * # Further Info
717 *
718 * The partitions form the boundries between each slice of matrix data. Hence, for each
719 * dimension, each partition may range between 0 and the length of the dimension inclusive.
720 *
721 * For one dimension of length 5, you can supply 0 up to 6 partitions,
722 * `[0,1,2,3,4,5]` would split that dimension into 7, 0 to 0, 0 to 1, 1 to 2,
723 * 2 to 3, 3 to 4, 4 to 5 and 5 to 5. 0 to 0 and 5 to 5 would of course be empty and the
724 * 5 parts in between would each be of length 1 along that dimension.
725 * `[2,4]` would instead split that dimension into three parts of 0 to 2, 2 to 4, and 4 to 5.
726 * `[]` would not split that dimension at all, and give a single part of 0 to 5.
727 *
728 * `partition` does this along both dimensions, and returns the parts in row major order, so
729 * you will receive a list of R+1 * C+1 length where R is the length of the row partitions
730 * provided and C is the length of the column partitions provided. If you just want to split
731 * a matrix into a 2x2 grid see [`partition_quadrants`](Matrix::partition_quadrants) which
732 * provides a dedicated API with more ergonomics for extracting the parts.
733 */
734 #[track_caller]
735 pub fn partition(
736 &mut self,
737 row_partitions: &[Row],
738 column_partitions: &[Column],
739 ) -> Vec<MatrixView<T, MatrixPart<T>>> {
740 let rows = self.rows();
741 let columns = self.columns();
742 fn check_axis(partitions: &[usize], length: usize) {
743 let mut previous: Option<usize> = None;
744 for &index in partitions {
745 assert!(index <= length);
746 previous = match previous {
747 None => Some(index),
748 Some(i) => {
749 assert!(index > i, "{:?} must be ascending", partitions);
750 Some(i)
751 }
752 }
753 }
754 }
755 check_axis(row_partitions, rows);
756 check_axis(column_partitions, columns);
757
758 // There will be one more slice than partitions, since partitions are the boundries
759 // between slices.
760 let row_slices = row_partitions.len() + 1;
761 let column_slices = column_partitions.len() + 1;
762 let total_slices = row_slices * column_slices;
763 let mut slices: Vec<Vec<&mut [T]>> = Vec::with_capacity(total_slices);
764 let (_, mut data) = self.data.split_at_mut(0);
765
766 let mut index = 0;
767 for r in 0..row_slices {
768 let row_index = row_partitions.get(r).cloned().unwrap_or(rows);
769 // Determine how many rows of our matrix we need for the next set of row slices
770 let rows_included = row_index - index;
771 for _ in 0..column_slices {
772 slices.push(Vec::with_capacity(rows_included));
773 }
774 index = row_index;
775
776 for _ in 0..rows_included {
777 // Partition the next row of our matrix along the columns
778 let mut index = 0;
779 for c in 0..column_slices {
780 let column_index = column_partitions.get(c).cloned().unwrap_or(columns);
781 let columns_included = column_index - index;
782 index = column_index;
783 // Split off as many elements as included in this column slice
784 let (slice, rest) = data.split_at_mut(columns_included);
785 // Insert the slice into the slices, we'll push `rows_included` times into
786 // each slice Vec.
787 slices[(r * column_slices) + c].push(slice);
788 data = rest;
789 }
790 }
791 }
792 // rest is now empty, so we can ignore it.
793
794 slices
795 .into_iter()
796 .map(|slices| {
797 let rows = slices.len();
798 let columns = slices.first().map(|columns| columns.len()).unwrap_or(0);
799 if columns == 0 {
800 // We may have allocated N rows but if each column in that row has no size
801 // our actual size is 0x0
802 MatrixView::from(MatrixPart::new(slices, 0, 0))
803 } else {
804 MatrixView::from(MatrixPart::new(slices, rows, columns))
805 }
806 })
807 .collect()
808 }
809
810 /**
811 * Partition a matrix into 4 non overlapping quadrants. Top left starts at 0,0 until
812 * exclusive of row and column, bottom right starts at row and column to the end of the matrix.
813 *
814 * # Panics
815 *
816 * Panics if the row or column are greater than the number of rows or columns in the matrix.
817 *
818 * # Examples
819 *
820 * ```
821 * use easy_ml::matrices::Matrix;
822 * let mut matrix = Matrix::from(vec![
823 * vec![ 0, 1, 2 ],
824 * vec![ 3, 4, 5 ],
825 * vec![ 6, 7, 8 ]
826 * ]);
827 * // Split the matrix at the second row and first column giving 2x1, 2x2, 1x1 and 2x1
828 * // quadrants.
829 * // 0 | 1 2
830 * // 3 | 4 5
831 * // -------
832 * // 6 | 7 8
833 * let mut parts = matrix.partition_quadrants(2, 1);
834 * assert_eq!(parts.top_left, Matrix::column(vec![ 0, 3 ]));
835 * assert_eq!(parts.top_right, Matrix::from(vec![vec![ 1, 2 ], vec![ 4, 5 ]]));
836 * assert_eq!(parts.bottom_left, Matrix::column(vec![ 6 ]));
837 * assert_eq!(parts.bottom_right, Matrix::row(vec![ 7, 8 ]));
838 * // Modify the matrix data independently without worrying about the borrow checker
839 * parts.top_right.map_mut(|x| x + 10);
840 * parts.bottom_left.map_mut(|x| x - 10);
841 * // Drop MatrixQuadrants so we can use the matrix directly again
842 * std::mem::drop(parts);
843 * assert_eq!(matrix, Matrix::from(vec![
844 * vec![ 0, 11, 12 ],
845 * vec![ 3, 14, 15 ],
846 * vec![ -4, 7, 8 ]
847 * ]));
848 * ```
849 */
850 #[track_caller]
851 #[allow(clippy::needless_lifetimes)] // false positive?
852 pub fn partition_quadrants<'a>(
853 &'a mut self,
854 row: Row,
855 column: Column,
856 ) -> MatrixQuadrants<'a, T> {
857 let mut parts = self.partition(&[row], &[column]).into_iter();
858 // We know there will be exactly 4 parts returned by the partition since we provided
859 // 1 row and 1 column to partition ourself into 4 with.
860 MatrixQuadrants {
861 top_left: parts.next().unwrap(),
862 top_right: parts.next().unwrap(),
863 bottom_left: parts.next().unwrap(),
864 bottom_right: parts.next().unwrap(),
865 }
866 }
867
868 /**
869 * Returns a MatrixView giving a view of only the data within the row and column
870 * [IndexRange]s.
871 *
872 * This is a shorthand for constructing the MatrixView from this Matrix.
873 *
874 * ```
875 * use easy_ml::matrices::Matrix;
876 * use easy_ml::matrices::views::{MatrixView, MatrixRange, IndexRange};
877 * let ab = Matrix::from(vec![
878 * vec![ 0, 1, 2, 0 ],
879 * vec![ 3, 4, 5, 1 ]
880 * ]);
881 * let shorter = ab.range(0..2, 1..3);
882 * assert_eq!(
883 * shorter,
884 * Matrix::from(vec![
885 * vec![ 1, 2 ],
886 * vec![ 4, 5 ]
887 * ])
888 * );
889 * ```
890 */
891 pub fn range<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, &Matrix<T>>>
892 where
893 R: Into<IndexRange>,
894 {
895 MatrixView::from(MatrixRange::from(self, rows, columns))
896 }
897
898 /**
899 * Returns a MatrixView giving a view of only the data within the row and column
900 * [IndexRange]s. The MatrixRange mutably borrows this Matrix, and can
901 * therefore mutate it.
902 *
903 * This is a shorthand for constructing the MatrixView from this Matrix.
904 */
905 pub fn range_mut<R>(
906 &mut self,
907 rows: R,
908 columns: R,
909 ) -> MatrixView<T, MatrixRange<T, &mut Matrix<T>>>
910 where
911 R: Into<IndexRange>,
912 {
913 MatrixView::from(MatrixRange::from(self, rows, columns))
914 }
915
916 /**
917 * Returns a MatrixView giving a view of only the data within the row and column
918 * [IndexRange]s. The MatrixRange takes ownership of this Matrix, and can
919 * therefore mutate it.
920 *
921 * This is a shorthand for constructing the MatrixView from this Matrix.
922 */
923 pub fn range_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, Matrix<T>>>
924 where
925 R: Into<IndexRange>,
926 {
927 MatrixView::from(MatrixRange::from(self, rows, columns))
928 }
929
930 /**
931 * Returns a MatrixView giving a view of only the data outside the row and column
932 * [IndexRange]s.
933 *
934 * This is a shorthand for constructing the MatrixView from this Matrix.
935 *
936 * ```
937 * use easy_ml::matrices::Matrix;
938 * use easy_ml::matrices::views::{MatrixView, MatrixMask, IndexRange};
939 * let ab = Matrix::from(vec![
940 * vec![ 0, 1, 2, 0 ],
941 * vec![ 3, 4, 5, 1 ]
942 * ]);
943 * let shorter = ab.mask(0..1, 1..3);
944 * assert_eq!(
945 * shorter,
946 * Matrix::from(vec![
947 * vec![ 3, 1 ]
948 * ])
949 * );
950 * ```
951 */
952 pub fn mask<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, &Matrix<T>>>
953 where
954 R: Into<IndexRange>,
955 {
956 MatrixView::from(MatrixMask::from(self, rows, columns))
957 }
958
959 /**
960 * Returns a MatrixView giving a view of only the data outside the row and column
961 * [IndexRange]s. The MatrixMask mutably borrows this Matrix, and can
962 * therefore mutate it.
963 *
964 * This is a shorthand for constructing the MatrixView from this Matrix.
965 */
966 pub fn mask_mut<R>(
967 &mut self,
968 rows: R,
969 columns: R,
970 ) -> MatrixView<T, MatrixMask<T, &mut Matrix<T>>>
971 where
972 R: Into<IndexRange>,
973 {
974 MatrixView::from(MatrixMask::from(self, rows, columns))
975 }
976
977 /**
978 * Returns a MatrixView giving a view of only the data outside the row and column
979 * [IndexRange]s. The MatrixMask takes ownership of this Matrix, and can
980 * therefore mutate it.
981 *
982 * This is a shorthand for constructing the MatrixView from this Matrix.
983 */
984 pub fn mask_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, Matrix<T>>>
985 where
986 R: Into<IndexRange>,
987 {
988 MatrixView::from(MatrixMask::from(self, rows, columns))
989 }
990
991 /**
992 * Returns a MatrixView with the rows and columns specified reversed in iteration
993 * order. The data of this matrix and the dimension lengths remain unchanged.
994 *
995 * This is a shorthand for constructing the MatrixView from this Matrix.
996 *
997 * ```
998 * use easy_ml::matrices::Matrix;
999 * use easy_ml::matrices::views::{MatrixView, MatrixReverse, Reverse};
1000 * let ab = Matrix::from(vec![
1001 * vec![ 0, 1, 2 ],
1002 * vec![ 3, 4, 5 ]
1003 * ]);
1004 * let reversed = ab.reverse(Reverse { rows: true, ..Default::default() });
1005 * let also_reversed = MatrixView::from(
1006 * MatrixReverse::from(&ab, Reverse { rows: true, columns: false })
1007 * );
1008 * assert_eq!(reversed, also_reversed);
1009 * assert_eq!(
1010 * reversed,
1011 * Matrix::from(vec![
1012 * vec![ 3, 4, 5 ],
1013 * vec![ 0, 1, 2 ]
1014 * ])
1015 * );
1016 * ```
1017 */
1018 pub fn reverse(&self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, &Matrix<T>>> {
1019 MatrixView::from(MatrixReverse::from(self, reverse))
1020 }
1021
1022 /**
1023 * Returns a MatrixView with the rows and columns specified reversed in iteration
1024 * order. The data of this matrix and the dimension lengths remain unchanged. The MatrixReverse
1025 * mutably borrows this Matrix, and can therefore mutate it
1026 *
1027 * This is a shorthand for constructing the MatrixView from this Matrix.
1028 */
1029 pub fn reverse_mut(
1030 &mut self,
1031 reverse: Reverse,
1032 ) -> MatrixView<T, MatrixReverse<T, &mut Matrix<T>>> {
1033 MatrixView::from(MatrixReverse::from(self, reverse))
1034 }
1035
1036 /**
1037 * Returns a MatrixView with the rows and columns specified reversed in iteration
1038 * order. The data of this matrix and the dimension lengths remain unchanged. The MatrixReverse
1039 * takes ownership of this Matrix, and can therefore mutate it
1040 *
1041 * This is a shorthand for constructing the MatrixView from this Matrix.
1042 */
1043 pub fn reverse_owned(self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, Matrix<T>>> {
1044 MatrixView::from(MatrixReverse::from(self, reverse))
1045 }
1046
1047 /**
1048 * Converts this Matrix into a 2 dimensional Tensor with the provided dimension names.
1049 *
1050 * This is a wrapper around the `TryFrom<(Matrix<T>, [Dimension; 2])>` implementation.
1051 *
1052 * The Tensor will have the data in the same order, a shape with lengths of `self.rows()` then
1053 * `self.columns()` and the provided dimension names respectively.
1054 *
1055 * Result::Err is returned if the `rows` and `columns` dimension names are the same.
1056 */
1057 pub fn into_tensor(
1058 self,
1059 rows: crate::tensors::Dimension,
1060 columns: crate::tensors::Dimension,
1061 ) -> Result<crate::tensors::Tensor<T, 2>, crate::tensors::InvalidShapeError<2>> {
1062 (self, [rows, columns]).try_into()
1063 }
1064}
1065
1066/**
1067 * Methods for matrices with types that can be copied, but still not neccessarily numerical.
1068 */
1069impl<T: Clone> Matrix<T> {
1070 /**
1071 * Computes and returns the transpose of this matrix
1072 *
1073 * ```
1074 * use easy_ml::matrices::Matrix;
1075 * let x = Matrix::from(vec![
1076 * vec![ 1, 2 ],
1077 * vec![ 3, 4 ]]);
1078 * let y = Matrix::from(vec![
1079 * vec![ 1, 3 ],
1080 * vec![ 2, 4 ]]);
1081 * assert_eq!(x.transpose(), y);
1082 * ```
1083 */
1084 pub fn transpose(&self) -> Matrix<T> {
1085 Matrix::from_fn((self.columns(), self.rows()), |(column, row)| {
1086 self.get(row, column)
1087 })
1088 }
1089
1090 /**
1091 * Transposes the matrix in place (if it is square).
1092 *
1093 * ```
1094 * use easy_ml::matrices::Matrix;
1095 * let mut x = Matrix::from(vec![
1096 * vec![ 1, 2 ],
1097 * vec![ 3, 4 ]]);
1098 * x.transpose_mut();
1099 * let y = Matrix::from(vec![
1100 * vec![ 1, 3 ],
1101 * vec![ 2, 4 ]]);
1102 * assert_eq!(x, y);
1103 * ```
1104 *
1105 * Note: None square matrices were erroneously not supported in previous versions (<=1.8.0) and
1106 * could be incorrectly mutated. This method will now correctly transpose non square matrices
1107 * by not attempting to transpose them in place.
1108 */
1109 pub fn transpose_mut(&mut self) {
1110 if self.rows() != self.columns() {
1111 let transposed = self.transpose();
1112 self.data = transposed.data;
1113 self.rows = transposed.rows;
1114 self.columns = transposed.columns;
1115 } else {
1116 for i in 0..self.rows() {
1117 for j in 0..self.columns() {
1118 if i > j {
1119 continue;
1120 }
1121 let temp = self.get(i, j);
1122 self.set(i, j, self.get(j, i));
1123 self.set(j, i, temp);
1124 }
1125 }
1126 }
1127 }
1128
1129 /**
1130 * Returns an iterator over a column vector in this matrix. Columns are 0 indexed.
1131 *
1132 * If you have a matrix such as:
1133 * ```ignore
1134 * [
1135 * 1, 2, 3
1136 * 4, 5, 6
1137 * 7, 8, 9
1138 * ]
1139 * ```
1140 * then a column of 0, 1, and 2 will yield [1, 4, 7], [2, 5, 8] and [3, 6, 9]
1141 * respectively. If you do not need to copy the elements use
1142 * [`column_reference_iter`](Matrix::column_reference_iter) instead.
1143 *
1144 * # Panics
1145 *
1146 * Panics if the column does not exist in this matrix.
1147 */
1148 #[track_caller]
1149 pub fn column_iter(&self, column: Column) -> ColumnIterator<T> {
1150 ColumnIterator::new(self, column)
1151 }
1152
1153 /**
1154 * Returns an iterator over a row vector in this matrix. Rows are 0 indexed.
1155 *
1156 * If you have a matrix such as:
1157 * ```ignore
1158 * [
1159 * 1, 2, 3
1160 * 4, 5, 6
1161 * 7, 8, 9
1162 * ]
1163 * ```
1164 * then a row of 0, 1, and 2 will yield [1, 2, 3], [4, 5, 6] and [7, 8, 9]
1165 * respectively. If you do not need to copy the elements use
1166 * [`row_reference_iter`](Matrix::row_reference_iter) instead.
1167 *
1168 * # Panics
1169 *
1170 * Panics if the row does not exist in this matrix.
1171 */
1172 #[track_caller]
1173 pub fn row_iter(&self, row: Row) -> RowIterator<T> {
1174 RowIterator::new(self, row)
1175 }
1176
1177 /**
1178 * Returns a column major iterator over all values in this matrix, proceeding through each
1179 * column in order.
1180 *
1181 * If you have a matrix such as:
1182 * ```ignore
1183 * [
1184 * 1, 2
1185 * 3, 4
1186 * ]
1187 * ```
1188 * then the iterator will yield [1, 3, 2, 4]. If you do not need to copy the
1189 * elements use [`column_major_reference_iter`](Matrix::column_major_reference_iter) instead.
1190 */
1191 pub fn column_major_iter(&self) -> ColumnMajorIterator<T> {
1192 ColumnMajorIterator::new(self)
1193 }
1194
1195 /**
1196 * Returns a row major iterator over all values in this matrix, proceeding through each
1197 * row in order.
1198 *
1199 * If you have a matrix such as:
1200 * ```ignore
1201 * [
1202 * 1, 2
1203 * 3, 4
1204 * ]
1205 * ```
1206 * then the iterator will yield [1, 2, 3, 4]. If you do not need to copy the
1207 * elements use [`row_major_reference_iter`](Matrix::row_major_reference_iter) instead.
1208 */
1209 pub fn row_major_iter(&self) -> RowMajorIterator<T> {
1210 RowMajorIterator::new(self)
1211 }
1212
1213 /**
1214 * Returns a iterator over the main diagonal of this matrix.
1215 *
1216 * If you have a matrix such as:
1217 * ```ignore
1218 * [
1219 * 1, 2
1220 * 3, 4
1221 * ]
1222 * ```
1223 * then the iterator will yield [1, 4]. If you do not need to copy the
1224 * elements use [`diagonal_reference_iter`](Matrix::diagonal_reference_iter) instead.
1225 *
1226 * # Examples
1227 *
1228 * Computing a [trace](https://en.wikipedia.org/wiki/Trace_(linear_algebra))
1229 * ```
1230 * use easy_ml::matrices::Matrix;
1231 * let matrix = Matrix::from(vec![
1232 * vec![ 1, 2, 3 ],
1233 * vec![ 4, 5, 6 ],
1234 * vec![ 7, 8, 9 ],
1235 * ]);
1236 * let trace: i32 = matrix.diagonal_iter().sum();
1237 * assert_eq!(trace, 1 + 5 + 9);
1238 * ```
1239 */
1240 pub fn diagonal_iter(&self) -> DiagonalIterator<T> {
1241 DiagonalIterator::new(self)
1242 }
1243
1244 /**
1245 * Creates a matrix of the provided size with all elements initialised to the provided value
1246 *
1247 * # Panics
1248 *
1249 * Panics if no values are provided. Note: this method erroneously did not validate its inputs
1250 * in Easy ML versions up to and including 1.7.0
1251 */
1252 #[track_caller]
1253 pub fn empty(value: T, size: (Row, Column)) -> Matrix<T> {
1254 assert!(size.0 > 0 && size.1 > 0, "Size must be at least 1x1");
1255 Matrix {
1256 data: vec![value; size.0 * size.1],
1257 rows: size.0,
1258 columns: size.1,
1259 }
1260 }
1261
1262 /**
1263 * Gets a copy of the value at this row and column. Rows and Columns are 0 indexed.
1264 *
1265 * # Panics
1266 *
1267 * Panics if the index is out of range.
1268 */
1269 #[track_caller]
1270 pub fn get(&self, row: Row, column: Column) -> T {
1271 assert!(
1272 row < self.rows(),
1273 "Row out of index, only have {} rows",
1274 self.rows()
1275 );
1276 assert!(
1277 column < self.columns(),
1278 "Column out of index, only have {} columns",
1279 self.columns()
1280 );
1281 self.data[self.get_index(row, column)].clone()
1282 }
1283
1284 /**
1285 * Similar to matrix.get(0, 0) in that this returns the element in the first
1286 * row and first column, except that this method will panic if the matrix is
1287 * not 1x1.
1288 *
1289 * This is provided as a convenience function when you want to convert a unit matrix
1290 * to a scalar, such as after taking a dot product of two vectors.
1291 *
1292 * # Example
1293 *
1294 * ```
1295 * use easy_ml::matrices::Matrix;
1296 * let x = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
1297 * let sum_of_squares: f64 = (x.transpose() * x).scalar();
1298 * ```
1299 *
1300 * # Panics
1301 *
1302 * Panics if the matrix is not 1x1
1303 */
1304 #[track_caller]
1305 pub fn scalar(&self) -> T {
1306 assert!(
1307 self.rows() == 1,
1308 "Cannot treat matrix as scalar as it has more than one row"
1309 );
1310 assert!(
1311 self.columns() == 1,
1312 "Cannot treat matrix as scalar as it has more than one column"
1313 );
1314 self.get(0, 0)
1315 }
1316
1317 /**
1318 * Applies a function to all values in the matrix, modifying
1319 * the matrix in place.
1320 */
1321 pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1322 for value in self.data.iter_mut() {
1323 *value = mapping_function(value.clone());
1324 }
1325 }
1326
1327 /**
1328 * Applies a function to all values and each value's index in the
1329 * matrix, modifying the matrix in place.
1330 */
1331 pub fn map_mut_with_index(&mut self, mapping_function: impl Fn(T, Row, Column) -> T) {
1332 self.row_major_reference_mut_iter()
1333 .with_index()
1334 .for_each(|((i, j), x)| {
1335 *x = mapping_function(x.clone(), i, j);
1336 });
1337 }
1338
1339 /**
1340 * Creates and returns a new matrix with all values from the original with the
1341 * function applied to each. This can be used to change the type of the matrix
1342 * such as creating a mask:
1343 * ```
1344 * use easy_ml::matrices::Matrix;
1345 * let x = Matrix::from(vec![
1346 * vec![ 0.0, 1.2 ],
1347 * vec![ 5.8, 6.9 ]]);
1348 * let y = x.map(|element| element > 2.0);
1349 * let result = Matrix::from(vec![
1350 * vec![ false, false ],
1351 * vec![ true, true ]]);
1352 * assert_eq!(&y, &result);
1353 * ```
1354 */
1355 pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Matrix<U>
1356 where
1357 U: Clone,
1358 {
1359 let mapped = self
1360 .data
1361 .iter()
1362 .map(|x| mapping_function(x.clone()))
1363 .collect();
1364 Matrix::from_flat_row_major(self.size(), mapped)
1365 }
1366
1367 /**
1368 * Creates and returns a new matrix with all values from the original
1369 * and the index of each value mapped by a function. This can be used
1370 * to perform elementwise operations that are not defined on the
1371 * Matrix type itself.
1372 *
1373 * # Exmples
1374 *
1375 * Matrix elementwise division:
1376 *
1377 * ```
1378 * use easy_ml::matrices::Matrix;
1379 * let x = Matrix::from(vec![
1380 * vec![ 9.0, 2.0 ],
1381 * vec![ 4.0, 3.0 ]]);
1382 * let y = Matrix::from(vec![
1383 * vec![ 3.0, 2.0 ],
1384 * vec![ 1.0, 3.0 ]]);
1385 * let z = x.map_with_index(|x, row, column| x / y.get(row, column));
1386 * let result = Matrix::from(vec![
1387 * vec![ 3.0, 1.0 ],
1388 * vec![ 4.0, 1.0 ]]);
1389 * assert_eq!(&z, &result);
1390 * ```
1391 */
1392 pub fn map_with_index<U>(&self, mapping_function: impl Fn(T, Row, Column) -> U) -> Matrix<U>
1393 where
1394 U: Clone,
1395 {
1396 let mapped = self
1397 .row_major_iter()
1398 .with_index()
1399 .map(|((i, j), x)| mapping_function(x, i, j))
1400 .collect();
1401 Matrix::from_flat_row_major(self.size(), mapped)
1402 }
1403
1404 /**
1405 * Inserts a new row into the Matrix at the provided index,
1406 * shifting other rows to the right and filling all entries with the
1407 * provided value. Rows are 0 indexed.
1408 *
1409 * # Panics
1410 *
1411 * This will panic if the row is greater than the number of rows in the matrix.
1412 */
1413 #[track_caller]
1414 pub fn insert_row(&mut self, row: Row, value: T) {
1415 assert!(
1416 row <= self.rows(),
1417 "Row to insert must be <= to {}",
1418 self.rows()
1419 );
1420 for column in 0..self.columns() {
1421 self.data.insert(self.get_index(row, column), value.clone());
1422 }
1423 self.rows += 1;
1424 }
1425
1426 /**
1427 * Inserts a new row into the Matrix at the provided index, shifting other rows
1428 * to the right and filling all entries with the values from the iterator in sequence.
1429 * Rows are 0 indexed.
1430 *
1431 * # Panics
1432 *
1433 * This will panic if the row is greater than the number of rows in the matrix,
1434 * or if the iterator has fewer elements than `self.columns()`.
1435 *
1436 * Example of duplicating a row:
1437 * ```
1438 * use easy_ml::matrices::Matrix;
1439 * let x: Matrix<u8> = Matrix::row(vec![ 1, 2, 3 ]);
1440 * let mut y = x.clone();
1441 * // duplicate the first row as the second row
1442 * y.insert_row_with(1, x.row_iter(0));
1443 * assert_eq!((2, 3), y.size());
1444 * let mut values = y.column_major_iter();
1445 * assert_eq!(Some(1), values.next());
1446 * assert_eq!(Some(1), values.next());
1447 * assert_eq!(Some(2), values.next());
1448 * assert_eq!(Some(2), values.next());
1449 * assert_eq!(Some(3), values.next());
1450 * assert_eq!(Some(3), values.next());
1451 * assert_eq!(None, values.next());
1452 * ```
1453 */
1454 #[track_caller]
1455 pub fn insert_row_with<I>(&mut self, row: Row, mut values: I)
1456 where
1457 I: Iterator<Item = T>,
1458 {
1459 assert!(
1460 row <= self.rows(),
1461 "Row to insert must be <= to {}",
1462 self.rows()
1463 );
1464 for column in 0..self.columns() {
1465 self.data.insert(
1466 self.get_index(row, column),
1467 values.next().unwrap_or_else(|| {
1468 panic!("At least {} values must be provided", self.columns())
1469 }),
1470 );
1471 }
1472 self.rows += 1;
1473 }
1474
1475 /**
1476 * Inserts a new column into the Matrix at the provided index, shifting other
1477 * columns to the right and filling all entries with the provided value.
1478 * Columns are 0 indexed.
1479 *
1480 * # Panics
1481 *
1482 * This will panic if the column is greater than the number of columns in the matrix.
1483 */
1484 #[track_caller]
1485 pub fn insert_column(&mut self, column: Column, value: T) {
1486 assert!(
1487 column <= self.columns(),
1488 "Column to insert must be <= to {}",
1489 self.columns()
1490 );
1491 for row in (0..self.rows()).rev() {
1492 self.data.insert(self.get_index(row, column), value.clone());
1493 }
1494 self.columns += 1;
1495 }
1496
1497 /**
1498 * Inserts a new column into the Matrix at the provided index, shifting other columns
1499 * to the right and filling all entries with the values from the iterator in sequence.
1500 * Columns are 0 indexed.
1501 *
1502 * # Panics
1503 *
1504 * This will panic if the column is greater than the number of columns in the matrix,
1505 * or if the iterator has fewer elements than `self.rows()`.
1506 *
1507 * Example of duplicating a column:
1508 * ```
1509 * use easy_ml::matrices::Matrix;
1510 * let x: Matrix<u8> = Matrix::column(vec![ 1, 2, 3 ]);
1511 * let mut y = x.clone();
1512 * // duplicate the first column as the second column
1513 * y.insert_column_with(1, x.column_iter(0));
1514 * assert_eq!((3, 2), y.size());
1515 * let mut values = y.column_major_iter();
1516 * assert_eq!(Some(1), values.next());
1517 * assert_eq!(Some(2), values.next());
1518 * assert_eq!(Some(3), values.next());
1519 * assert_eq!(Some(1), values.next());
1520 * assert_eq!(Some(2), values.next());
1521 * assert_eq!(Some(3), values.next());
1522 * assert_eq!(None, values.next());
1523 * ```
1524 */
1525 #[track_caller]
1526 pub fn insert_column_with<I>(&mut self, column: Column, values: I)
1527 where
1528 I: Iterator<Item = T>,
1529 {
1530 assert!(
1531 column <= self.columns(),
1532 "Column to insert must be <= to {}",
1533 self.columns()
1534 );
1535 let mut array_values = values.collect::<Vec<T>>();
1536 assert!(
1537 array_values.len() >= self.rows(),
1538 "At least {} values must be provided",
1539 self.rows()
1540 );
1541 for row in (0..self.rows()).rev() {
1542 self.data
1543 .insert(self.get_index(row, column), array_values.pop().unwrap());
1544 }
1545 self.columns += 1;
1546 }
1547
1548 /**
1549 * Makes a copy of this matrix shrunk down in size according to the slice. See
1550 * [retain_mut](Matrix::retain_mut()).
1551 */
1552 pub fn retain(&self, slice: Slice2D) -> Matrix<T> {
1553 let mut retained = self.clone();
1554 retained.retain_mut(slice);
1555 retained
1556 }
1557}
1558
1559/**
1560 * Any matrix of a Cloneable type implements Clone.
1561 */
1562impl<T: Clone> Clone for Matrix<T> {
1563 fn clone(&self) -> Self {
1564 self.map(|element| element)
1565 }
1566}
1567
1568/**
1569 * Any matrix of a Displayable type implements Display
1570 *
1571 * You can control the precision of the formatting using format arguments, i.e.
1572 * `format!("{:.3}", matrix)`
1573 */
1574impl<T: std::fmt::Display> std::fmt::Display for Matrix<T> {
1575 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1576 crate::matrices::views::format_view(self, f)
1577 }
1578}
1579
1580/**
1581 * Any matrix and two different dimension names can be converted to a 2 dimensional tensor with
1582 * the same number of rows and columns.
1583 *
1584 * Conversion will fail if the dimension names for `self.rows()` and `self.columns()` respectively
1585 * are the same.
1586 */
1587impl<T> TryFrom<(Matrix<T>, [crate::tensors::Dimension; 2])> for crate::tensors::Tensor<T, 2> {
1588 type Error = crate::tensors::InvalidShapeError<2>;
1589
1590 fn try_from(value: (Matrix<T>, [crate::tensors::Dimension; 2])) -> Result<Self, Self::Error> {
1591 let (matrix, [row_name, column_name]) = value;
1592 let shape = [(row_name, matrix.rows), (column_name, matrix.columns)];
1593 let check = crate::tensors::InvalidShapeError::new(shape);
1594 if !check.is_valid() {
1595 return Err(check);
1596 }
1597 // Now we know the shape is valid, we can call the standard Tensor constructor knowing
1598 // it won't fail since our data length will match the size of our shape.
1599 Ok(crate::tensors::Tensor::from(shape, matrix.data))
1600 }
1601}
1602
1603/**
1604 * Methods for matrices with numerical types, such as f32 or f64.
1605 *
1606 * Note that unsigned integers are not Numeric because they do not
1607 * implement [Neg](std::ops::Neg). You must first
1608 * wrap unsigned integers via [Wrapping](std::num::Wrapping) or [Saturating](std::num::Saturating).
1609 *
1610 * While these methods will all be defined on signed integer types as well, such as i16 or i32,
1611 * in many cases integers cannot be used sensibly in these computations. If you
1612 * have a matrix of type i8 for example, you should consider mapping it into a floating
1613 * type before doing heavy linear algebra maths on it.
1614 *
1615 * Determinants can be computed without loss of precision using sufficiently large signed
1616 * integers because the only operations performed on the elements are addition, subtraction
1617 * and mulitplication. However the inverse of a matrix such as
1618 *
1619 * ```ignore
1620 * [
1621 * 4, 7
1622 * 2, 8
1623 * ]
1624 * ```
1625 *
1626 * is
1627 *
1628 * ```ignore
1629 * [
1630 * 0.6, -0.7,
1631 * -0.2, 0.4
1632 * ]
1633 * ```
1634 *
1635 * which requires a type that supports decimals to accurately represent.
1636 *
1637 * Mapping matrix type example:
1638 * ```
1639 * use easy_ml::matrices::Matrix;
1640 * use std::num::Wrapping;
1641 *
1642 * let matrix: Matrix<u8> = Matrix::from(vec![
1643 * vec![ 2, 3 ],
1644 * vec![ 6, 0 ]
1645 * ]);
1646 * // determinant is not defined on this matrix because u8 is not Numeric
1647 * // println!("{:?}", matrix.determinant()); // won't compile
1648 * // however Wrapping<u8> is numeric
1649 * let matrix = matrix.map(|element| Wrapping(element));
1650 * println!("{:?}", matrix.determinant()); // -> 238 (overflow)
1651 * println!("{:?}", matrix.map(|element| element.0 as i16).determinant()); // -> -18
1652 * println!("{:?}", matrix.map(|element| element.0 as f32).determinant()); // -> -18.0
1653 * ```
1654 */
1655impl<T: Numeric> Matrix<T>
1656where
1657 for<'a> &'a T: NumericRef<T>,
1658{
1659 /**
1660 * Returns the determinant of this square matrix, or None if the matrix
1661 * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant())
1662 */
1663 pub fn determinant(&self) -> Option<T> {
1664 linear_algebra::determinant::<T>(self)
1665 }
1666
1667 /**
1668 * Computes the inverse of a matrix provided that it exists. To have an inverse a
1669 * matrix must be square (same number of rows and columns) and it must also have a
1670 * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse())
1671 */
1672 pub fn inverse(&self) -> Option<Matrix<T>> {
1673 linear_algebra::inverse::<T>(self)
1674 }
1675
1676 /**
1677 * Computes the covariance matrix for this NxM feature matrix, in which
1678 * each N'th row has M features to find the covariance and variance of. See
1679 * [`linear_algebra`](super::linear_algebra::covariance_column_features())
1680 */
1681 pub fn covariance_column_features(&self) -> Matrix<T> {
1682 linear_algebra::covariance_column_features::<T>(self)
1683 }
1684
1685 /**
1686 * Computes the covariance matrix for this NxM feature matrix, in which
1687 * each M'th column has N features to find the covariance and variance of. See
1688 * [`linear_algebra`](super::linear_algebra::covariance_row_features())
1689 */
1690 pub fn covariance_row_features(&self) -> Matrix<T> {
1691 linear_algebra::covariance_row_features::<T>(self)
1692 }
1693}
1694
1695/**
1696 * Methods for matrices with numerical real valued types, such as f32 or f64.
1697 *
1698 * This excludes signed and unsigned integers as they do not support decimal
1699 * precision and hence can't be used for operations like square roots.
1700 *
1701 * Third party fixed precision and infinite precision decimal types should
1702 * be able to implement all of the methods for [Real] and then utilise these functions.
1703 */
1704impl<T: Real> Matrix<T>
1705where
1706 for<'a> &'a T: RealRef<T>,
1707{
1708 /**
1709 * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1710 * of this row or column vector, also referred to as the length or magnitude,
1711 * and written as ||x||, or sometimes |x|.
1712 *
1713 * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1714 *
1715 * This is a shorthand for `(x.transpose() * x).scalar().sqrt()` for
1716 * column vectors and `(x * x.transpose()).scalar().sqrt()` for row vectors, ie
1717 * the square root of the dot product of a vector with itself.
1718 *
1719 * The euclidean length can be used to compute a
1720 * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1721 * vector with length of 1. This should not be confused with a unit matrix,
1722 * which is another name for an identity matrix.
1723 *
1724 * ```
1725 * use easy_ml::matrices::Matrix;
1726 * let a = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
1727 * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1728 * let unit = a / length;
1729 * assert_eq!(unit.euclidean_length(), 1.0);
1730 * ```
1731 *
1732 * # Panics
1733 *
1734 * If the matrix is not a vector, ie if it has more than one row and more than one
1735 * column.
1736 */
1737 #[track_caller]
1738 pub fn euclidean_length(&self) -> T {
1739 if self.columns() == 1 {
1740 // column vector
1741 (self.transpose() * self).scalar().sqrt()
1742 } else if self.rows() == 1 {
1743 // row vector
1744 (self * self.transpose()).scalar().sqrt()
1745 } else {
1746 panic!(
1747 "Cannot compute unit vector of a non vector, rows: {}, columns: {}",
1748 self.rows(),
1749 self.columns()
1750 );
1751 }
1752 }
1753}
1754
1755// FIXME: want this to be callable in the main numeric impl block
1756impl<T: Numeric> Matrix<T> {
1757 /**
1758 * Creates a diagonal matrix of the provided size with the diagonal elements
1759 * set to the provided value and all other elements in the matrix set to 0.
1760 * A diagonal matrix is always square.
1761 *
1762 * The size is still taken as a tuple to facilitate creating a diagonal matrix
1763 * from the dimensionality of an existing one. If the provided value is 1 then
1764 * this will create an identity matrix.
1765 *
1766 * A 3 x 3 identity matrix:
1767 * ```ignore
1768 * [
1769 * 1, 0, 0
1770 * 0, 1, 0
1771 * 0, 0, 1
1772 * ]
1773 * ```
1774 *
1775 * # Panics
1776 *
1777 * If the provided size is not square.
1778 */
1779 #[track_caller]
1780 pub fn diagonal(value: T, size: (Row, Column)) -> Matrix<T> {
1781 assert!(size.0 == size.1);
1782 let mut matrix = Matrix::empty(T::zero(), size);
1783 for i in 0..size.0 {
1784 matrix.set(i, i, value.clone());
1785 }
1786 matrix
1787 }
1788
1789 /**
1790 * Creates a diagonal matrix with the elements along the diagonal set to the
1791 * provided values and all other elements in the matrix set to 0.
1792 * A diagonal matrix is always square.
1793 *
1794 * Examples
1795 *
1796 * ```
1797 * use easy_ml::matrices::Matrix;
1798 * let matrix = Matrix::from_diagonal(vec![ 1, 1, 1 ]);
1799 * assert_eq!(matrix.size(), (3, 3));
1800 * let copy = Matrix::from_diagonal(matrix.diagonal_iter().collect());
1801 * assert_eq!(matrix, copy);
1802 * assert_eq!(matrix, Matrix::from(vec![
1803 * vec![ 1, 0, 0 ],
1804 * vec![ 0, 1, 0 ],
1805 * vec![ 0, 0, 1 ],
1806 * ]))
1807 * ```
1808 */
1809 pub fn from_diagonal(values: Vec<T>) -> Matrix<T> {
1810 let mut matrix = Matrix::empty(T::zero(), (values.len(), values.len()));
1811 for (i, element) in values.into_iter().enumerate() {
1812 matrix.set(i, i, element);
1813 }
1814 matrix
1815 }
1816}
1817
1818/**
1819 * PartialEq is implemented as two matrices are equal if and only if all their elements
1820 * are equal and they have the same size.
1821 */
1822impl<T: PartialEq> PartialEq for Matrix<T> {
1823 #[inline]
1824 fn eq(&self, other: &Self) -> bool {
1825 if self.rows() != other.rows() {
1826 return false;
1827 }
1828 if self.columns() != other.columns() {
1829 return false;
1830 }
1831 // perform elementwise check, return true only if every element in
1832 // each matrix is the same
1833 self.data.iter().zip(other.data.iter()).all(|(x, y)| x == y)
1834 }
1835}
1836
#[test]
fn test_sync() {
    // Compiles only if Matrix<f64> can be shared between threads.
    fn requires_sync<S: Sync>() {}
    requires_sync::<Matrix<f64>>();
}
1842
#[test]
fn test_send() {
    // Compiles only if Matrix<f64> can be moved between threads.
    fn requires_send<S: Send>() {}
    requires_send::<Matrix<f64>>();
}
1848
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::matrices::{Column, Matrix, Row};
    use serde::{Deserialize, Deserializer};

    // Raw deserialized fields, checked before they are allowed to become a Matrix.
    #[derive(Deserialize)]
    #[serde(rename = "Matrix")]
    struct MatrixDeserialize<T> {
        data: Vec<T>,
        rows: Row,
        columns: Column,
    }

    impl<'de, T> Deserialize<'de> for Matrix<T>
    where
        T: Deserialize<'de>,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let d = MatrixDeserialize::<T>::deserialize(deserializer)?;
            // Serialized data is untrusted input, so size mismatches are reported as
            // deserialization errors instead of panicking inside the Matrix constructor.
            if d.rows == 0 || d.columns == 0 {
                return Err(serde::de::Error::custom(format!(
                    "Matrix size must be at least 1x1, got {}x{}",
                    d.rows, d.columns
                )));
            }
            if d.rows.checked_mul(d.columns) != Some(d.data.len()) {
                return Err(serde::de::Error::custom(format!(
                    "Matrix data length {} does not match size {}x{}",
                    d.data.len(),
                    d.rows,
                    d.columns
                )));
            }
            // The checks above mirror the invariants enforced by this validating
            // constructor, so it is still the only way the data becomes a Matrix.
            Ok(Matrix::from_flat_row_major((d.rows, d.columns), d.data))
        }
    }
}

#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    fn assert_serialize<T: Serialize>() {}
    assert_serialize::<Matrix<f64>>();
}

#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
    assert_deserialize::<Matrix<f64>>();
}

#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop() {
    #[rustfmt::skip]
    let matrix = Matrix::from(vec![
        vec![1, 2, 3, 4],
        vec![5, 6, 7, 8],
        vec![9, 10, 11, 12],
    ]);
    let encoded = toml::to_string(&matrix).unwrap();
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 4
"#,
    );
    let parsed: Result<Matrix<i32>, _> = toml::from_str(&encoded);
    assert!(parsed.is_ok());
    assert_eq!(matrix, parsed.unwrap())
}

#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    let result: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 3
"#,
    );
    assert!(result.is_err());
}

#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation_empty() {
    let result: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = []
rows = 0
columns = 3
"#,
    );
    assert!(result.is_err());
}
1928
#[test]
fn test_indexing() {
    // 2x2: flat index 1 is row 0, column 1
    let square = Matrix::from(vec![vec![1, 2], vec![3, 4]]);
    assert_eq!(square.get_index(0, 1), 1);
    assert_eq!(square.get_row_column(1), (0, 1));
    assert_eq!(square.get(0, 1), 2);

    // 2x3: flat index 5 is row 1, column 2
    let wide = Matrix::from(vec![vec![1, 2, 3], vec![5, 6, 7]]);
    assert_eq!(wide.get_index(1, 2), 5);
    assert_eq!(wide.get_row_column(5), (1, 2));
    assert_eq!(wide.get(1, 2), 7);

    // map_with_index must hand each element its own (row, column)
    let labelled = Matrix::from(vec![vec![0, 0], vec![0, 0], vec![0, 0]])
        .map_with_index(|_, r, c| format!("{:?}x{:?}", r, c));
    let expected = Matrix::from(vec![
        vec!["0x0", "0x1"],
        vec!["1x0", "1x1"],
        vec!["2x0", "2x1"],
    ])
    .map(|label| label.to_owned());
    assert_eq!(labelled, expected);
}