easy_ml/matrices/
mod.rs

1/*!
2 * Generic matrix type.
3 *
4 * Matrices are generic over some type `T`. If `T` is [Numeric](super::numeric) then
5 * the matrix can be used in a mathematical way.
6 */
7
8#[cfg(feature = "serde")]
9use serde::Serialize;
10
11mod errors;
12pub mod iterators;
13pub mod operations;
14pub mod slices;
15pub mod views;
16
17pub use errors::ScalarConversionError;
18
19use crate::linear_algebra;
20use crate::matrices::iterators::*;
21use crate::matrices::slices::Slice2D;
22use crate::matrices::views::{
23    IndexRange, MatrixMask, MatrixPart, MatrixQuadrants, MatrixRange, MatrixReverse, MatrixView,
24    Reverse,
25};
26use crate::numeric::extra::{Real, RealRef};
27use crate::numeric::{Numeric, NumericRef};
28
29/**
30 * A general purpose matrix of some type. This type may implement
31 * no traits, in which case the matrix will be rather useless. If the
32 * type implements [`Clone`] most storage and accessor methods are defined and if the type
33 * implements [`Numeric`](super::numeric) then the matrix can be used in
34 * a mathematical way.
35 *
36 * When doing numeric operations with Matrices you should be careful to not
37 * consume a matrix by accidentally using it by value. All the operations are
38 * also defined on references to matrices so you should favor `&x * &y` style
39 * notation for matrices you intend to continue using. There are also convenience
40 * operations defined for a matrix and a scalar.
41 *
42 * # Matrix size invariants
43 *
44 * Matrices must always be at least 1x1. You cannot construct a matrix with no rows or
45 * no columns, and any function that resizes matrices will error if you try to use it
46 * in a way that would construct a 0x1, 1x0, or 0x0 matrix. The maximum size of a matrix
47 * is dependent on the platform's `std::isize::MAX` value. Matrices with dimensions NxM
48 * such that N * M < `std::isize::MAX` should not cause any errors in this library, but
49 * attempting to expand their size further may cause panics and or errors. At the time of
50 * writing it is no longer possible to construct or use matrices where the product of their
51 * number of rows and columns exceed `std::isize::MAX`, but some constructor methods may be used
52 * to attempt this. Concerned readers should note that on a 64 bit computer this maximum
53 * value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
54 *
55 * # Matrix layout and iterator performance
56 *
57 * [See iterators submodule for Matrix layout and iterator performance](iterators#matrix-layout-and-iterator-performance)
58 *
59 * # Matrix operations
60 *
61 * [See operations submodule](operations)
62 */
63#[derive(Debug)]
64#[cfg_attr(feature = "serde", derive(Serialize))]
65pub struct Matrix<T> {
66    data: Vec<T>,
67    rows: Row,
68    columns: Column,
69}
70
/// Index or length along the rows of a [`Matrix`]. This is a `usize` because the
/// matrix's internal storage is backed by a `Vec`.
pub type Row = usize;
/// Index or length along the columns of a [`Matrix`]. This is a `usize` because the
/// matrix's internal storage is backed by a `Vec`.
pub type Column = usize;
75
76/**
77 * Methods for matrices of any type, including non numerical types such as bool.
78 */
79impl<T> Matrix<T> {
80    /**
81     * Creates a 1x1 matrix from some scalar
82     */
83    pub fn from_scalar(value: T) -> Matrix<T> {
84        Matrix {
85            data: vec![value],
86            rows: 1,
87            columns: 1,
88        }
89    }
90
91    /**
92     * Creates a row vector (1xN) from a list
93     *
94     * # Panics
95     *
96     * Panics if no values are provided. Note: this method erroneously did not validate its inputs
97     * in Easy ML versions up to and including 1.7.0
98     */
99    #[track_caller]
100    pub fn row(values: Vec<T>) -> Matrix<T> {
101        assert!(!values.is_empty(), "No values provided");
102        Matrix {
103            columns: values.len(),
104            data: values,
105            rows: 1,
106        }
107    }
108
109    /**
110     * Creates a column vector (Nx1) from a list
111     *
112     * # Panics
113     *
114     * Panics if no values are provided. Note: this method erroneously did not validate its inputs
115     * in Easy ML versions up to and including 1.7.0
116     */
117    #[track_caller]
118    pub fn column(values: Vec<T>) -> Matrix<T> {
119        assert!(!values.is_empty(), "No values provided");
120        Matrix {
121            rows: values.len(),
122            data: values,
123            columns: 1,
124        }
125    }
126
127    /**
128     * Creates a matrix from a nested array of values, each inner vector
129     * being a row, and hence the outer vector containing all rows in sequence, the
130     * same way as when writing matrices in mathematics.
131     *
132     * Example of a 2 x 3 matrix in both notations:
133     * ```ignore
134     *   [
135     *      1, 2, 4
136     *      8, 9, 3
137     *   ]
138     * ```
139     * ```
140     * use easy_ml::matrices::Matrix;
141     * Matrix::from(vec![
142     *     vec![ 1, 2, 4 ],
143     *     vec![ 8, 9, 3 ]
144     * ]);
145     * ```
146     *
147     * # Panics
148     *
149     * Panics if the input is jagged or rows or column length is 0.
150     */
151    #[track_caller]
152    pub fn from(mut values: Vec<Vec<T>>) -> Matrix<T> {
153        assert!(!values.is_empty(), "No rows defined");
154        // check length of first row is > 1
155        assert!(!values[0].is_empty(), "No column defined");
156        // check length of each row is the same
157        assert!(
158            values.iter().map(|x| x.len()).all(|x| x == values[0].len()),
159            "Inconsistent size"
160        );
161        // flatten the data into a row major layout
162        let rows = values.len();
163        let columns = values[0].len();
164        let mut data = Vec::with_capacity(rows * columns);
165        let mut value_stream = values.drain(..);
166        for _ in 0..rows {
167            let mut value_row_stream = value_stream.next().unwrap();
168            let mut row_of_values = value_row_stream.drain(..);
169            for _ in 0..columns {
170                data.push(row_of_values.next().unwrap());
171            }
172        }
173        Matrix {
174            data,
175            rows,
176            columns,
177        }
178    }
179
180    /**
181     * Creates a matrix with the specified size from a row major vec of data.
182     * The length of the vec must match the size of the matrix or the constructor
183     * will panic.
184     *
185     * Example of a 2 x 3 matrix in both notations:
186     * ```ignore
187     *   [
188     *      1, 2, 4
189     *      8, 9, 3
190     *   ]
191     * ```
192     * ```
193     * use easy_ml::matrices::Matrix;
194     * Matrix::from_flat_row_major((2, 3), vec![
195     *     1, 2, 4,
196     *     8, 9, 3
197     * ]);
198     * ```
199     *
200     * This method is more efficient than [`Matrix::from`](Matrix::from())
201     * but requires specifying the size explicitly and manually keeping track of where rows
202     * start and stop.
203     *
204     * # Panics
205     *
206     * Panics if the length of the vec does not match the size of the matrix, or no values are
207     * provided. Note: this method erroneously did not validate its inputs were not empty in
208     * Easy ML versions up to and including 1.7.0
209     */
210    #[track_caller]
211    pub fn from_flat_row_major(size: (Row, Column), values: Vec<T>) -> Matrix<T> {
212        assert!(
213            size.0 * size.1 == values.len(),
214            "Inconsistent size, attempted to construct a {}x{} matrix but provided with {} elements.",
215            size.0,
216            size.1,
217            values.len()
218        );
219        assert!(!values.is_empty(), "No values provided");
220        Matrix {
221            data: values,
222            rows: size.0,
223            columns: size.1,
224        }
225    }
226
227    /**
228     * Creates a matrix with the specified size initalised from a function.
229     *
230     * ```
231     * use easy_ml::matrices::Matrix;
232     * let matrix = Matrix::from_fn((4, 4), |(r, c)| r * c);
233     * assert_eq!(
234     *     matrix,
235     *     Matrix::from(vec![
236     *         vec![ 0, 0, 0, 0 ],
237     *         vec![ 0, 1, 2, 3 ],
238     *         vec![ 0, 2, 4, 6 ],
239     *         vec![ 0, 3, 6, 9 ],
240     *     ])
241     * );
242     * ```
243     *
244     * # Panics
245     *
246     * Panics if the size has 0 rows or columns.
247     */
248    #[track_caller]
249    pub fn from_fn<F>(size: (Row, Column), mut producer: F) -> Matrix<T>
250    where
251        F: FnMut((Row, Column)) -> T,
252    {
253        use crate::tensors::indexing::ShapeIterator;
254        let length = size.0 * size.1;
255        let mut data = Vec::with_capacity(length);
256        let iterator = ShapeIterator::from([("row", size.0), ("column", size.1)]);
257        for [r, c] in iterator {
258            data.push(producer((r, c)));
259        }
260        Matrix::from_flat_row_major(size, data)
261    }
262
263    #[deprecated(
264        since = "1.1.0",
265        note = "Incorrect use of terminology, a unit matrix is another term for an identity matrix, please use `from_scalar` instead"
266    )]
267    pub fn unit(value: T) -> Matrix<T> {
268        Matrix::from_scalar(value)
269    }
270
271    /**
272     * Returns the dimensionality of this matrix in Row, Column format
273     */
274    pub fn size(&self) -> (Row, Column) {
275        (self.rows, self.columns)
276    }
277
278    /**
279     * Gets the number of rows in this matrix.
280     */
281    pub fn rows(&self) -> Row {
282        self.rows
283    }
284
285    /**
286     * Gets the number of columns in this matrix.
287     */
288    pub fn columns(&self) -> Column {
289        self.columns
290    }
291
292    /**
293     * Matrix data is stored as row major, so each row is stored as
294     * adjacent items going through the different columns. Therefore,
295     * to index this flattened representation we jump down in row sized
296     * blocks to reach the correct row, and then jump further equal to
297     * the column. The confusing thing is that the number of columns
298     * this matrix has is the length of each of the rows in this matrix,
299     * and vice versa.
300     */
301    fn get_index(&self, row: Row, column: Column) -> usize {
302        column + (row * self.columns())
303    }
304
305    /**
306     * The reverse of [get_index], converts from the flattened storage
307     * in memory into the row and column to index at this position.
308     *
309     * Matrix data is stored as row major, so each multiple of the number
310     * of columns starts a new row, and each index modulo the columns
311     * gives the column.
312     */
313    #[allow(dead_code)]
314    fn get_row_column(&self, index: usize) -> (Row, Column) {
315        (index / self.columns(), index % self.columns())
316    }
317
318    /**
319     * Gets a reference to the value at this row and column. Rows and Columns are 0 indexed.
320     *
321     * # Panics
322     *
323     * Panics if the index is out of range.
324     */
325    #[track_caller]
326    pub fn get_reference(&self, row: Row, column: Column) -> &T {
327        assert!(row < self.rows(), "Row out of index");
328        assert!(column < self.columns(), "Column out of index");
329        &self.data[self.get_index(row, column)]
330    }
331
332    /**
333     * Gets a mutable reference to the value at this row and column.
334     * Rows and Columns are 0 indexed.
335     *
336     * # Panics
337     *
338     * Panics if the index is out of range.
339     */
340    #[track_caller]
341    pub fn get_reference_mut(&mut self, row: Row, column: Column) -> &mut T {
342        assert!(row < self.rows(), "Row out of index");
343        assert!(column < self.columns(), "Column out of index");
344        let index = self.get_index(row, column);
345        // borrow for get_index ends
346        &mut self.data[index]
347    }
348
349    /**
350     * Not public API because don't want to name clash with the method on MatrixRef
351     * that calls this.
352     */
353    pub(crate) fn _try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
354        if row < self.rows() && column < self.columns() {
355            Some(&self.data[self.get_index(row, column)])
356        } else {
357            None
358        }
359    }
360
361    /**
362     * Not public API because don't want to name clash with the method on MatrixRef
363     * that calls this.
364     */
365    pub(crate) unsafe fn _get_reference_unchecked(&self, row: Row, column: Column) -> &T {
366        unsafe { self.data.get_unchecked(self.get_index(row, column)) }
367    }
368
369    /**
370     * Sets a new value to this row and column. Rows and Columns are 0 indexed.
371     *
372     * # Panics
373     *
374     * Panics if the index is out of range.
375     */
376    #[track_caller]
377    pub fn set(&mut self, row: Row, column: Column, value: T) {
378        assert!(row < self.rows(), "Row out of index");
379        assert!(column < self.columns(), "Column out of index");
380        let index = self.get_index(row, column);
381        // borrow for get_index ends
382        self.data[index] = value;
383    }
384
385    /**
386     * Not public API because don't want to name clash with the method on MatrixMut
387     * that calls this.
388     */
389    pub(crate) fn _try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
390        if row < self.rows() && column < self.columns() {
391            let index = self.get_index(row, column);
392            // borrow for get_index ends
393            Some(&mut self.data[index])
394        } else {
395            None
396        }
397    }
398
399    /**
400     * Not public API because don't want to name clash with the method on MatrixMut
401     * that calls this.
402     */
403    pub(crate) unsafe fn _get_reference_unchecked_mut(
404        &mut self,
405        row: Row,
406        column: Column,
407    ) -> &mut T {
408        unsafe {
409            let index = self.get_index(row, column);
410            // borrow for get_index ends
411            self.data.get_unchecked_mut(index)
412        }
413    }
414
415    /**
416     * Removes a row from this Matrix, shifting all other rows to the left.
417     * Rows are 0 indexed.
418     *
419     * # Panics
420     *
421     * This will panic if the row does not exist or the matrix only has one row.
422     */
423    #[track_caller]
424    pub fn remove_row(&mut self, row: Row) {
425        assert!(self.rows() > 1);
426        let mut r = 0;
427        let mut c = 0;
428        // drop the values at the specified row
429        let columns = self.columns();
430        self.data.retain(|_| {
431            let keep = r != row;
432            if c < (columns - 1) {
433                c += 1;
434            } else {
435                r += 1;
436                c = 0;
437            }
438            keep
439        });
440        self.rows -= 1;
441    }
442
443    /**
444     * Removes a column from this Matrix, shifting all other columns to the left.
445     * Columns are 0 indexed.
446     *
447     * # Panics
448     *
449     * This will panic if the column does not exist or the matrix only has one column.
450     */
451    #[track_caller]
452    pub fn remove_column(&mut self, column: Column) {
453        assert!(self.columns() > 1);
454        let mut r = 0;
455        let mut c = 0;
456        // drop the values at the specified column
457        let columns = self.columns();
458        self.data.retain(|_| {
459            let keep = c != column;
460            if c < (columns - 1) {
461                c += 1;
462            } else {
463                r += 1;
464                c = 0;
465            }
466            keep
467        });
468        self.columns -= 1;
469    }
470
471    /**
472     * Returns an iterator over references to a column vector in this matrix.
473     * Columns are 0 indexed.
474     *
475     * # Panics
476     *
477     * Panics if the column does not exist in this matrix.
478     */
479    #[track_caller]
480    pub fn column_reference_iter(&self, column: Column) -> ColumnReferenceIterator<T> {
481        ColumnReferenceIterator::new(self, column)
482    }
483
484    /**
485     * Returns an iterator over references to a row vector in this matrix.
486     * Rows are 0 indexed.
487     *
488     * # Panics
489     *
490     * Panics if the row does not exist in this matrix.
491     */
492    #[track_caller]
493    pub fn row_reference_iter(&self, row: Row) -> RowReferenceIterator<T> {
494        RowReferenceIterator::new(self, row)
495    }
496
497    /**
498     * Returns an iterator over mutable references to a column vector in this matrix.
499     * Columns are 0 indexed.
500     *
501     * # Panics
502     *
503     * Panics if the column does not exist in this matrix.
504     */
505    #[track_caller]
506    pub fn column_reference_mut_iter(&mut self, column: Column) -> ColumnReferenceMutIterator<T> {
507        ColumnReferenceMutIterator::new(self, column)
508    }
509
510    /**
511     * Returns an iterator over mutable references to a row vector in this matrix.
512     * Rows are 0 indexed.
513     *
514     * # Panics
515     *
516     * Panics if the row does not exist in this matrix.
517     */
518    #[track_caller]
519    pub fn row_reference_mut_iter(&mut self, row: Row) -> RowReferenceMutIterator<T> {
520        RowReferenceMutIterator::new(self, row)
521    }
522
523    /**
524     * Returns a column major iterator over references to all values in this matrix,
525     * proceeding through each column in order.
526     */
527    pub fn column_major_reference_iter(&self) -> ColumnMajorReferenceIterator<T> {
528        ColumnMajorReferenceIterator::new(self)
529    }
530
531    /**
532     * Returns a row major iterator over references to all values in this matrix,
533     * proceeding through each row in order.
534     */
535    pub fn row_major_reference_iter(&self) -> RowMajorReferenceIterator<T> {
536        RowMajorReferenceIterator::new(self)
537    }
538
539    // Non public row major reference iterator since we don't want to expose our implementation
540    // details to public API since then we could never change them.
541    pub(crate) fn direct_row_major_reference_iter(&self) -> std::slice::Iter<T> {
542        self.data.iter()
543    }
544
545    // Non public row major reference iterator since we don't want to expose our implementation
546    // details to public API since then we could never change them.
547    pub(crate) fn direct_row_major_reference_iter_mut(&mut self) -> std::slice::IterMut<T> {
548        self.data.iter_mut()
549    }
550
551    /**
552     * Returns a column major iterator over mutable references to all values in this matrix,
553     * proceeding through each column in order.
554     */
555    pub fn column_major_reference_mut_iter(&mut self) -> ColumnMajorReferenceMutIterator<T> {
556        ColumnMajorReferenceMutIterator::new(self)
557    }
558
559    /**
560     * Returns a row major iterator over mutable references to all values in this matrix,
561     * proceeding through each row in order.
562     */
563    pub fn row_major_reference_mut_iter(&mut self) -> RowMajorReferenceMutIterator<T> {
564        RowMajorReferenceMutIterator::new(self)
565    }
566
567    /**
568     * Creates a column major iterator over all values in this matrix,
569     * proceeding through each column in order.
570     */
571    pub fn column_major_owned_iter(self) -> ColumnMajorOwnedIterator<T>
572    where
573        T: Default,
574    {
575        ColumnMajorOwnedIterator::new(self)
576    }
577
578    /**
579     * Creates a row major iterator over all values in this matrix,
580     * proceeding through each row in order.
581     */
582    pub fn row_major_owned_iter(self) -> RowMajorOwnedIterator<T>
583    where
584        T: Default,
585    {
586        RowMajorOwnedIterator::new(self)
587    }
588
589    /**
590     * Returns an iterator over references to the main diagonal in this matrix.
591     */
592    pub fn diagonal_reference_iter(&self) -> DiagonalReferenceIterator<T> {
593        DiagonalReferenceIterator::new(self)
594    }
595
596    /**
597     * Returns an iterator over mutable references to the main diagonal in this matrix.
598     */
599    pub fn diagonal_reference_mut_iter(&mut self) -> DiagonalReferenceMutIterator<T> {
600        DiagonalReferenceMutIterator::new(self)
601    }
602
603    /**
604     * Shrinks this matrix down from its current MxN size down to
605     * some new size OxP where O and P are determined by the kind of
606     * slice given and 1 <= O <= M and 1 <= P <= N.
607     *
608     * Only rows and columns specified by the slice will be retained, so for
609     * instance if the Slice is constructed by
610     * `Slice2D::new().rows(Slice::Range(0..2)).columns(Slice::Range(0..3))` then the
611     * modified matrix will be no bigger than 2x3 and contain up to the first two
612     * rows and first three columns that it previously had.
613     *
614     * See [Slice](slices::Slice) for constructing slices.
615     *
616     * # Panics
617     *
618     * This function will panic if the slice would delete all rows or all columns
619     * from this matrix, ie the resulting matrix must be at least 1x1.
620     */
621    #[track_caller]
622    pub fn retain_mut(&mut self, slice: Slice2D) {
623        let mut r = 0;
624        let mut c = 0;
625        // drop the values rejected by the slice
626        let columns = self.columns();
627        self.data.retain(|_| {
628            let keep = slice.accepts(r, c);
629            if c < (columns - 1) {
630                c += 1;
631            } else {
632                r += 1;
633                c = 0;
634            }
635            keep
636        });
637        // work out the resulting size of this matrix by using the non
638        // public fields of the Slice2D to handle each row and column
639        // seperately.
640        let remaining_rows = {
641            let mut accepted = 0;
642            for i in 0..self.rows() {
643                if slice.rows.accepts(i) {
644                    accepted += 1;
645                }
646            }
647            accepted
648        };
649        let remaining_columns = {
650            let mut accepted = 0;
651            for i in 0..self.columns() {
652                if slice.columns.accepts(i) {
653                    accepted += 1;
654                }
655            }
656            accepted
657        };
658        assert!(
659            remaining_rows > 0,
660            "Provided slice must leave at least 1 row in the retained matrix"
661        );
662        assert!(
663            remaining_columns > 0,
664            "Provided slice must leave at least 1 column in the retained matrix"
665        );
666        assert!(
667            !self.data.is_empty(),
668            "Provided slice must leave at least 1 row and 1 column in the retained matrix"
669        );
670        self.rows = remaining_rows;
671        self.columns = remaining_columns
672        // By construction jagged slices should be impossible, if this
673        // invariant later changes by accident it would be possible to break the
674        // rectangle shape invariant on a matrix object
675        // As Slice2D should prevent the construction of jagged slices no
676        // check is here to detect if all rows are still the same length
677    }
678
679    /**
680     * Consumes a 1x1 matrix and converts it into a scalar without copying the data.
681     *
682     * # Example
683     *
684     * ```
685     * use easy_ml::matrices::Matrix;
686     * # fn main() -> Result<(), Box<dyn std::error::Error>> {
687     * let x = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
688     * let sum_of_squares: f64 = (x.transpose() * x).try_into_scalar()?;
689     * # Ok(())
690     * # }
691     * ```
692     */
693    pub fn try_into_scalar(self) -> Result<T, ScalarConversionError> {
694        if self.size() == (1, 1) {
695            Ok(self.data.into_iter().next().unwrap())
696        } else {
697            Err(ScalarConversionError {})
698        }
699    }
700
701    /**
702     * Partition a matrix into an arbitary number of non overlapping parts.
703     *
704     * **This function is much like a hammer you should be careful to not overuse. If you don't need
705     * to mutate the parts of the matrix data individually it will be much easier and less error
706     * prone to create immutable views into the matrix using [MatrixRange] instead.**
707     *
708     * Parts are returned in row major order, forming a grid of slices into the Matrix data that
709     * can be mutated independently.
710     *
711     * # Panics
712     *
713     * Panics if any row or column index is greater than the number of rows or columns in the
714     * matrix. Each list of row partitions and column partitions must also be in ascending order.
715     *
716     * # Further Info
717     *
718     * The partitions form the boundries between each slice of matrix data. Hence, for each
719     * dimension, each partition may range between 0 and the length of the dimension inclusive.
720     *
721     * For one dimension of length 5, you can supply 0 up to 6 partitions,
722     * `[0,1,2,3,4,5]` would split that dimension into 7, 0 to 0, 0 to 1, 1 to 2,
723     * 2 to 3, 3 to 4, 4 to 5 and 5 to 5. 0 to 0 and 5 to 5 would of course be empty and the
724     * 5 parts in between would each be of length 1 along that dimension.
725     * `[2,4]` would instead split that dimension into three parts of 0 to 2, 2 to 4, and 4 to 5.
726     * `[]` would not split that dimension at all, and give a single part of 0 to 5.
727     *
728     * `partition` does this along both dimensions, and returns the parts in row major order, so
729     * you will receive a list of R+1 * C+1 length where R is the length of the row partitions
730     * provided and C is the length of the column partitions provided. If you just want to split
731     * a matrix into a 2x2 grid see [`partition_quadrants`](Matrix::partition_quadrants) which
732     * provides a dedicated API with more ergonomics for extracting the parts.
733     */
734    #[track_caller]
735    pub fn partition(
736        &mut self,
737        row_partitions: &[Row],
738        column_partitions: &[Column],
739    ) -> Vec<MatrixView<T, MatrixPart<T>>> {
740        let rows = self.rows();
741        let columns = self.columns();
742        fn check_axis(partitions: &[usize], length: usize) {
743            let mut previous: Option<usize> = None;
744            for &index in partitions {
745                assert!(index <= length);
746                previous = match previous {
747                    None => Some(index),
748                    Some(i) => {
749                        assert!(index > i, "{:?} must be ascending", partitions);
750                        Some(i)
751                    }
752                }
753            }
754        }
755        check_axis(row_partitions, rows);
756        check_axis(column_partitions, columns);
757
758        // There will be one more slice than partitions, since partitions are the boundries
759        // between slices.
760        let row_slices = row_partitions.len() + 1;
761        let column_slices = column_partitions.len() + 1;
762        let total_slices = row_slices * column_slices;
763        let mut slices: Vec<Vec<&mut [T]>> = Vec::with_capacity(total_slices);
764        let (_, mut data) = self.data.split_at_mut(0);
765
766        let mut index = 0;
767        for r in 0..row_slices {
768            let row_index = row_partitions.get(r).cloned().unwrap_or(rows);
769            // Determine how many rows of our matrix we need for the next set of row slices
770            let rows_included = row_index - index;
771            for _ in 0..column_slices {
772                slices.push(Vec::with_capacity(rows_included));
773            }
774            index = row_index;
775
776            for _ in 0..rows_included {
777                // Partition the next row of our matrix along the columns
778                let mut index = 0;
779                for c in 0..column_slices {
780                    let column_index = column_partitions.get(c).cloned().unwrap_or(columns);
781                    let columns_included = column_index - index;
782                    index = column_index;
783                    // Split off as many elements as included in this column slice
784                    let (slice, rest) = data.split_at_mut(columns_included);
785                    // Insert the slice into the slices, we'll push `rows_included` times into
786                    // each slice Vec.
787                    slices[(r * column_slices) + c].push(slice);
788                    data = rest;
789                }
790            }
791        }
792        // rest is now empty, so we can ignore it.
793
794        slices
795            .into_iter()
796            .map(|slices| {
797                let rows = slices.len();
798                let columns = slices.first().map(|columns| columns.len()).unwrap_or(0);
799                if columns == 0 {
800                    // We may have allocated N rows but if each column in that row has no size
801                    // our actual size is 0x0
802                    MatrixView::from(MatrixPart::new(slices, 0, 0))
803                } else {
804                    MatrixView::from(MatrixPart::new(slices, rows, columns))
805                }
806            })
807            .collect()
808    }
809
810    /**
811     * Partition a matrix into 4 non overlapping quadrants. Top left starts at 0,0 until
812     * exclusive of row and column, bottom right starts at row and column to the end of the matrix.
813     *
814     * # Panics
815     *
816     * Panics if the row or column are greater than the number of rows or columns in the matrix.
817     *
818     * # Examples
819     *
820     * ```
821     * use easy_ml::matrices::Matrix;
822     * let mut matrix = Matrix::from(vec![
823     *     vec![ 0, 1, 2 ],
824     *     vec![ 3, 4, 5 ],
825     *     vec![ 6, 7, 8 ]
826     * ]);
827     * // Split the matrix at the second row and first column giving 2x1, 2x2, 1x1 and 2x1
828     * // quadrants.
829     * // 0 | 1 2
830     * // 3 | 4 5
831     * // -------
832     * // 6 | 7 8
833     * let mut parts = matrix.partition_quadrants(2, 1);
834     * assert_eq!(parts.top_left, Matrix::column(vec![ 0, 3 ]));
835     * assert_eq!(parts.top_right, Matrix::from(vec![vec![ 1, 2 ], vec![ 4, 5 ]]));
836     * assert_eq!(parts.bottom_left, Matrix::column(vec![ 6 ]));
837     * assert_eq!(parts.bottom_right, Matrix::row(vec![ 7, 8 ]));
838     * // Modify the matrix data independently without worrying about the borrow checker
839     * parts.top_right.map_mut(|x| x + 10);
840     * parts.bottom_left.map_mut(|x| x - 10);
841     * // Drop MatrixQuadrants so we can use the matrix directly again
842     * std::mem::drop(parts);
843     * assert_eq!(matrix, Matrix::from(vec![
844     *     vec![ 0, 11, 12 ],
845     *     vec![ 3, 14, 15 ],
846     *     vec![ -4, 7, 8 ]
847     * ]));
848     * ```
849     */
850    #[track_caller]
851    #[allow(clippy::needless_lifetimes)] // false positive?
852    pub fn partition_quadrants<'a>(
853        &'a mut self,
854        row: Row,
855        column: Column,
856    ) -> MatrixQuadrants<'a, T> {
857        let mut parts = self.partition(&[row], &[column]).into_iter();
858        // We know there will be exactly 4 parts returned by the partition since we provided
859        // 1 row and 1 column to partition ourself into 4 with.
860        MatrixQuadrants {
861            top_left: parts.next().unwrap(),
862            top_right: parts.next().unwrap(),
863            bottom_left: parts.next().unwrap(),
864            bottom_right: parts.next().unwrap(),
865        }
866    }
867
868    /**
869     * Returns a MatrixView giving a view of only the data within the row and column
870     * [IndexRange]s.
871     *
872     * This is a shorthand for constructing the MatrixView from this Matrix.
873     *
874     * ```
875     * use easy_ml::matrices::Matrix;
876     * use easy_ml::matrices::views::{MatrixView, MatrixRange, IndexRange};
877     * let ab = Matrix::from(vec![
878     *     vec![ 0, 1, 2, 0 ],
879     *     vec![ 3, 4, 5, 1 ]
880     * ]);
881     * let shorter = ab.range(0..2, 1..3);
882     * assert_eq!(
883     *     shorter,
884     *     Matrix::from(vec![
885     *        vec![ 1, 2 ],
886     *        vec![ 4, 5 ]
887     *     ])
888     * );
889     * ```
890     */
891    pub fn range<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, &Matrix<T>>>
892    where
893        R: Into<IndexRange>,
894    {
895        MatrixView::from(MatrixRange::from(self, rows, columns))
896    }
897
898    /**
899     * Returns a MatrixView giving a view of only the data within the row and column
900     * [IndexRange]s. The MatrixRange mutably borrows this Matrix, and can
901     * therefore mutate it.
902     *
903     * This is a shorthand for constructing the MatrixView from this Matrix.
904     */
905    pub fn range_mut<R>(
906        &mut self,
907        rows: R,
908        columns: R,
909    ) -> MatrixView<T, MatrixRange<T, &mut Matrix<T>>>
910    where
911        R: Into<IndexRange>,
912    {
913        MatrixView::from(MatrixRange::from(self, rows, columns))
914    }
915
916    /**
917     * Returns a MatrixView giving a view of only the data within the row and column
918     * [IndexRange]s. The MatrixRange takes ownership of this Matrix, and can
919     * therefore mutate it.
920     *
921     * This is a shorthand for constructing the MatrixView from this Matrix.
922     */
923    pub fn range_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixRange<T, Matrix<T>>>
924    where
925        R: Into<IndexRange>,
926    {
927        MatrixView::from(MatrixRange::from(self, rows, columns))
928    }
929
930    /**
931     * Returns a MatrixView giving a view of only the data outside the row and column
932     * [IndexRange]s.
933     *
934     * This is a shorthand for constructing the MatrixView from this Matrix.
935     *
936     * ```
937     * use easy_ml::matrices::Matrix;
938     * use easy_ml::matrices::views::{MatrixView, MatrixMask, IndexRange};
939     * let ab = Matrix::from(vec![
940     *     vec![ 0, 1, 2, 0 ],
941     *     vec![ 3, 4, 5, 1 ]
942     * ]);
943     * let shorter = ab.mask(0..1, 1..3);
944     * assert_eq!(
945     *     shorter,
946     *     Matrix::from(vec![
947     *        vec![ 3, 1 ]
948     *     ])
949     * );
950     * ```
951     */
952    pub fn mask<R>(&self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, &Matrix<T>>>
953    where
954        R: Into<IndexRange>,
955    {
956        MatrixView::from(MatrixMask::from(self, rows, columns))
957    }
958
959    /**
960     * Returns a MatrixView giving a view of only the data outside the row and column
961     * [IndexRange]s. The MatrixMask mutably borrows this Matrix, and can
962     * therefore mutate it.
963     *
964     * This is a shorthand for constructing the MatrixView from this Matrix.
965     */
966    pub fn mask_mut<R>(
967        &mut self,
968        rows: R,
969        columns: R,
970    ) -> MatrixView<T, MatrixMask<T, &mut Matrix<T>>>
971    where
972        R: Into<IndexRange>,
973    {
974        MatrixView::from(MatrixMask::from(self, rows, columns))
975    }
976
977    /**
978     * Returns a MatrixView giving a view of only the data outside the row and column
979     * [IndexRange]s. The MatrixMask takes ownership of this Matrix, and can
980     * therefore mutate it.
981     *
982     * This is a shorthand for constructing the MatrixView from this Matrix.
983     */
984    pub fn mask_owned<R>(self, rows: R, columns: R) -> MatrixView<T, MatrixMask<T, Matrix<T>>>
985    where
986        R: Into<IndexRange>,
987    {
988        MatrixView::from(MatrixMask::from(self, rows, columns))
989    }
990
991    /**
992     * Returns a MatrixView with the rows and columns specified reversed in iteration
993     * order. The data of this matrix and the dimension lengths remain unchanged.
994     *
995     * This is a shorthand for constructing the MatrixView from this Matrix.
996     *
997     * ```
998     * use easy_ml::matrices::Matrix;
999     * use easy_ml::matrices::views::{MatrixView, MatrixReverse, Reverse};
1000     * let ab = Matrix::from(vec![
1001     *     vec![ 0, 1, 2 ],
1002     *     vec![ 3, 4, 5 ]
1003     * ]);
1004     * let reversed = ab.reverse(Reverse { rows: true, ..Default::default() });
1005     * let also_reversed = MatrixView::from(
1006     *     MatrixReverse::from(&ab, Reverse { rows: true, columns: false })
1007     * );
1008     * assert_eq!(reversed, also_reversed);
1009     * assert_eq!(
1010     *     reversed,
1011     *     Matrix::from(vec![
1012     *         vec![ 3, 4, 5 ],
1013     *         vec![ 0, 1, 2 ]
1014     *     ])
1015     * );
1016     * ```
1017     */
1018    pub fn reverse(&self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, &Matrix<T>>> {
1019        MatrixView::from(MatrixReverse::from(self, reverse))
1020    }
1021
1022    /**
1023     * Returns a MatrixView with the rows and columns specified reversed in iteration
1024     * order. The data of this matrix and the dimension lengths remain unchanged. The MatrixReverse
1025     * mutably borrows this Matrix, and can therefore mutate it
1026     *
1027     * This is a shorthand for constructing the MatrixView from this Matrix.
1028     */
1029    pub fn reverse_mut(
1030        &mut self,
1031        reverse: Reverse,
1032    ) -> MatrixView<T, MatrixReverse<T, &mut Matrix<T>>> {
1033        MatrixView::from(MatrixReverse::from(self, reverse))
1034    }
1035
1036    /**
1037     * Returns a MatrixView with the rows and columns specified reversed in iteration
1038     * order. The data of this matrix and the dimension lengths remain unchanged. The MatrixReverse
1039     * takes ownership of this Matrix, and can therefore mutate it
1040     *
1041     * This is a shorthand for constructing the MatrixView from this Matrix.
1042     */
1043    pub fn reverse_owned(self, reverse: Reverse) -> MatrixView<T, MatrixReverse<T, Matrix<T>>> {
1044        MatrixView::from(MatrixReverse::from(self, reverse))
1045    }
1046
1047    /**
1048     * Converts this Matrix into a 2 dimensional Tensor with the provided dimension names.
1049     *
1050     * This is a wrapper around the `TryFrom<(Matrix<T>, [Dimension; 2])>` implementation.
1051     *
1052     * The Tensor will have the data in the same order, a shape with lengths of `self.rows()` then
1053     * `self.columns()` and the provided dimension names respectively.
1054     *
1055     * Result::Err is returned if the `rows` and `columns` dimension names are the same.
1056     */
1057    pub fn into_tensor(
1058        self,
1059        rows: crate::tensors::Dimension,
1060        columns: crate::tensors::Dimension,
1061    ) -> Result<crate::tensors::Tensor<T, 2>, crate::tensors::InvalidShapeError<2>> {
1062        (self, [rows, columns]).try_into()
1063    }
1064}
1065
1066/**
 * Methods for matrices with types that can be cloned, but are not necessarily numerical.
1068 */
1069impl<T: Clone> Matrix<T> {
1070    /**
1071     * Computes and returns the transpose of this matrix
1072     *
1073     * ```
1074     * use easy_ml::matrices::Matrix;
1075     * let x = Matrix::from(vec![
1076     *    vec![ 1, 2 ],
1077     *    vec![ 3, 4 ]]);
1078     * let y = Matrix::from(vec![
1079     *    vec![ 1, 3 ],
1080     *    vec![ 2, 4 ]]);
1081     * assert_eq!(x.transpose(), y);
1082     * ```
1083     */
1084    pub fn transpose(&self) -> Matrix<T> {
1085        Matrix::from_fn((self.columns(), self.rows()), |(column, row)| {
1086            self.get(row, column)
1087        })
1088    }
1089
1090    /**
1091     * Transposes the matrix in place (if it is square).
1092     *
1093     * ```
1094     * use easy_ml::matrices::Matrix;
1095     * let mut x = Matrix::from(vec![
1096     *    vec![ 1, 2 ],
1097     *    vec![ 3, 4 ]]);
1098     * x.transpose_mut();
1099     * let y = Matrix::from(vec![
1100     *    vec![ 1, 3 ],
1101     *    vec![ 2, 4 ]]);
1102     * assert_eq!(x, y);
1103     * ```
1104     *
1105     * Note: None square matrices were erroneously not supported in previous versions (<=1.8.0) and
1106     * could be incorrectly mutated. This method will now correctly transpose non square matrices
1107     * by not attempting to transpose them in place.
1108     */
1109    pub fn transpose_mut(&mut self) {
1110        if self.rows() != self.columns() {
1111            let transposed = self.transpose();
1112            self.data = transposed.data;
1113            self.rows = transposed.rows;
1114            self.columns = transposed.columns;
1115        } else {
1116            for i in 0..self.rows() {
1117                for j in 0..self.columns() {
1118                    if i > j {
1119                        continue;
1120                    }
1121                    let temp = self.get(i, j);
1122                    self.set(i, j, self.get(j, i));
1123                    self.set(j, i, temp);
1124                }
1125            }
1126        }
1127    }
1128
1129    /**
1130     * Returns an iterator over a column vector in this matrix. Columns are 0 indexed.
1131     *
1132     * If you have a matrix such as:
1133     * ```ignore
1134     * [
1135     *    1, 2, 3
1136     *    4, 5, 6
1137     *    7, 8, 9
1138     * ]
1139     * ```
1140     * then a column of 0, 1, and 2 will yield [1, 4, 7], [2, 5, 8] and [3, 6, 9]
1141     * respectively. If you do not need to copy the elements use
1142     * [`column_reference_iter`](Matrix::column_reference_iter) instead.
1143     *
1144     * # Panics
1145     *
1146     * Panics if the column does not exist in this matrix.
1147     */
1148    #[track_caller]
1149    pub fn column_iter(&self, column: Column) -> ColumnIterator<T> {
1150        ColumnIterator::new(self, column)
1151    }
1152
1153    /**
1154     * Returns an iterator over a row vector in this matrix. Rows are 0 indexed.
1155     *
1156     * If you have a matrix such as:
1157     * ```ignore
1158     * [
1159     *    1, 2, 3
1160     *    4, 5, 6
1161     *    7, 8, 9
1162     * ]
1163     * ```
1164     * then a row of 0, 1, and 2 will yield [1, 2, 3], [4, 5, 6] and [7, 8, 9]
1165     * respectively. If you do not need to copy the elements use
1166     * [`row_reference_iter`](Matrix::row_reference_iter) instead.
1167     *
1168     * # Panics
1169     *
1170     * Panics if the row does not exist in this matrix.
1171     */
1172    #[track_caller]
1173    pub fn row_iter(&self, row: Row) -> RowIterator<T> {
1174        RowIterator::new(self, row)
1175    }
1176
1177    /**
1178     * Returns a column major iterator over all values in this matrix, proceeding through each
1179     * column in order.
1180     *
1181     * If you have a matrix such as:
1182     * ```ignore
1183     * [
1184     *    1, 2
1185     *    3, 4
1186     * ]
1187     * ```
1188     * then the iterator will yield [1, 3, 2, 4]. If you do not need to copy the
1189     * elements use [`column_major_reference_iter`](Matrix::column_major_reference_iter) instead.
1190     */
1191    pub fn column_major_iter(&self) -> ColumnMajorIterator<T> {
1192        ColumnMajorIterator::new(self)
1193    }
1194
1195    /**
1196     * Returns a row major iterator over all values in this matrix, proceeding through each
1197     * row in order.
1198     *
1199     * If you have a matrix such as:
1200     * ```ignore
1201     * [
1202     *    1, 2
1203     *    3, 4
1204     * ]
1205     * ```
1206     * then the iterator will yield [1, 2, 3, 4]. If you do not need to copy the
1207     * elements use [`row_major_reference_iter`](Matrix::row_major_reference_iter) instead.
1208     */
1209    pub fn row_major_iter(&self) -> RowMajorIterator<T> {
1210        RowMajorIterator::new(self)
1211    }
1212
1213    /**
1214     * Returns a iterator over the main diagonal of this matrix.
1215     *
1216     * If you have a matrix such as:
1217     * ```ignore
1218     * [
1219     *    1, 2
1220     *    3, 4
1221     * ]
1222     * ```
1223     * then the iterator will yield [1, 4]. If you do not need to copy the
1224     * elements use [`diagonal_reference_iter`](Matrix::diagonal_reference_iter) instead.
1225     *
1226     * # Examples
1227     *
1228     * Computing a [trace](https://en.wikipedia.org/wiki/Trace_(linear_algebra))
1229     * ```
1230     * use easy_ml::matrices::Matrix;
1231     * let matrix = Matrix::from(vec![
1232     *     vec![ 1, 2, 3 ],
1233     *     vec![ 4, 5, 6 ],
1234     *     vec![ 7, 8, 9 ],
1235     * ]);
1236     * let trace: i32 = matrix.diagonal_iter().sum();
1237     * assert_eq!(trace, 1 + 5 + 9);
1238     * ```
1239     */
1240    pub fn diagonal_iter(&self) -> DiagonalIterator<T> {
1241        DiagonalIterator::new(self)
1242    }
1243
1244    /**
1245     * Creates a matrix of the provided size with all elements initialised to the provided value
1246     *
1247     * # Panics
1248     *
1249     * Panics if no values are provided. Note: this method erroneously did not validate its inputs
1250     * in Easy ML versions up to and including 1.7.0
1251     */
1252    #[track_caller]
1253    pub fn empty(value: T, size: (Row, Column)) -> Matrix<T> {
1254        assert!(size.0 > 0 && size.1 > 0, "Size must be at least 1x1");
1255        Matrix {
1256            data: vec![value; size.0 * size.1],
1257            rows: size.0,
1258            columns: size.1,
1259        }
1260    }
1261
1262    /**
1263     * Gets a copy of the value at this row and column. Rows and Columns are 0 indexed.
1264     *
1265     * # Panics
1266     *
1267     * Panics if the index is out of range.
1268     */
1269    #[track_caller]
1270    pub fn get(&self, row: Row, column: Column) -> T {
1271        assert!(
1272            row < self.rows(),
1273            "Row out of index, only have {} rows",
1274            self.rows()
1275        );
1276        assert!(
1277            column < self.columns(),
1278            "Column out of index, only have {} columns",
1279            self.columns()
1280        );
1281        self.data[self.get_index(row, column)].clone()
1282    }
1283
1284    /**
1285     * Similar to matrix.get(0, 0) in that this returns the element in the first
1286     * row and first column, except that this method will panic if the matrix is
1287     * not 1x1.
1288     *
1289     * This is provided as a convenience function when you want to convert a unit matrix
1290     * to a scalar, such as after taking a dot product of two vectors.
1291     *
1292     * # Example
1293     *
1294     * ```
1295     * use easy_ml::matrices::Matrix;
1296     * let x = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
1297     * let sum_of_squares: f64 = (x.transpose() * x).scalar();
1298     * ```
1299     *
1300     * # Panics
1301     *
1302     * Panics if the matrix is not 1x1
1303     */
1304    #[track_caller]
1305    pub fn scalar(&self) -> T {
1306        assert!(
1307            self.rows() == 1,
1308            "Cannot treat matrix as scalar as it has more than one row"
1309        );
1310        assert!(
1311            self.columns() == 1,
1312            "Cannot treat matrix as scalar as it has more than one column"
1313        );
1314        self.get(0, 0)
1315    }
1316
1317    /**
1318     * Applies a function to all values in the matrix, modifying
1319     * the matrix in place.
1320     */
1321    pub fn map_mut(&mut self, mapping_function: impl Fn(T) -> T) {
1322        for value in self.data.iter_mut() {
1323            *value = mapping_function(value.clone());
1324        }
1325    }
1326
1327    /**
1328     * Applies a function to all values and each value's index in the
1329     * matrix, modifying the matrix in place.
1330     */
1331    pub fn map_mut_with_index(&mut self, mapping_function: impl Fn(T, Row, Column) -> T) {
1332        self.row_major_reference_mut_iter()
1333            .with_index()
1334            .for_each(|((i, j), x)| {
1335                *x = mapping_function(x.clone(), i, j);
1336            });
1337    }
1338
1339    /**
1340     * Creates and returns a new matrix with all values from the original with the
1341     * function applied to each. This can be used to change the type of the matrix
1342     * such as creating a mask:
1343     * ```
1344     * use easy_ml::matrices::Matrix;
1345     * let x = Matrix::from(vec![
1346     *    vec![ 0.0, 1.2 ],
1347     *    vec![ 5.8, 6.9 ]]);
1348     * let y = x.map(|element| element > 2.0);
1349     * let result = Matrix::from(vec![
1350     *    vec![ false, false ],
1351     *    vec![ true, true ]]);
1352     * assert_eq!(&y, &result);
1353     * ```
1354     */
1355    pub fn map<U>(&self, mapping_function: impl Fn(T) -> U) -> Matrix<U>
1356    where
1357        U: Clone,
1358    {
1359        let mapped = self
1360            .data
1361            .iter()
1362            .map(|x| mapping_function(x.clone()))
1363            .collect();
1364        Matrix::from_flat_row_major(self.size(), mapped)
1365    }
1366
1367    /**
1368     * Creates and returns a new matrix with all values from the original
1369     * and the index of each value mapped by a function. This can be used
1370     * to perform elementwise operations that are not defined on the
1371     * Matrix type itself.
1372     *
1373     * # Exmples
1374     *
1375     * Matrix elementwise division:
1376     *
1377     * ```
1378     * use easy_ml::matrices::Matrix;
1379     * let x = Matrix::from(vec![
1380     *     vec![ 9.0, 2.0 ],
1381     *     vec![ 4.0, 3.0 ]]);
1382     * let y = Matrix::from(vec![
1383     *     vec![ 3.0, 2.0 ],
1384     *     vec![ 1.0, 3.0 ]]);
1385     * let z = x.map_with_index(|x, row, column| x / y.get(row, column));
1386     * let result = Matrix::from(vec![
1387     *     vec![ 3.0, 1.0 ],
1388     *     vec![ 4.0, 1.0 ]]);
1389     * assert_eq!(&z, &result);
1390     * ```
1391     */
1392    pub fn map_with_index<U>(&self, mapping_function: impl Fn(T, Row, Column) -> U) -> Matrix<U>
1393    where
1394        U: Clone,
1395    {
1396        let mapped = self
1397            .row_major_iter()
1398            .with_index()
1399            .map(|((i, j), x)| mapping_function(x, i, j))
1400            .collect();
1401        Matrix::from_flat_row_major(self.size(), mapped)
1402    }
1403
1404    /**
1405     * Inserts a new row into the Matrix at the provided index,
1406     * shifting other rows to the right and filling all entries with the
1407     * provided value. Rows are 0 indexed.
1408     *
1409     * # Panics
1410     *
1411     * This will panic if the row is greater than the number of rows in the matrix.
1412     */
1413    #[track_caller]
1414    pub fn insert_row(&mut self, row: Row, value: T) {
1415        assert!(
1416            row <= self.rows(),
1417            "Row to insert must be <= to {}",
1418            self.rows()
1419        );
1420        for column in 0..self.columns() {
1421            self.data.insert(self.get_index(row, column), value.clone());
1422        }
1423        self.rows += 1;
1424    }
1425
1426    /**
1427     * Inserts a new row into the Matrix at the provided index, shifting other rows
1428     * to the right and filling all entries with the values from the iterator in sequence.
1429     * Rows are 0 indexed.
1430     *
1431     * # Panics
1432     *
1433     * This will panic if the row is greater than the number of rows in the matrix,
1434     * or if the iterator has fewer elements than `self.columns()`.
1435     *
1436     * Example of duplicating a row:
1437     * ```
1438     * use easy_ml::matrices::Matrix;
1439     * let x: Matrix<u8> = Matrix::row(vec![ 1, 2, 3 ]);
1440     * let mut y = x.clone();
1441     * // duplicate the first row as the second row
1442     * y.insert_row_with(1, x.row_iter(0));
1443     * assert_eq!((2, 3), y.size());
1444     * let mut values = y.column_major_iter();
1445     * assert_eq!(Some(1), values.next());
1446     * assert_eq!(Some(1), values.next());
1447     * assert_eq!(Some(2), values.next());
1448     * assert_eq!(Some(2), values.next());
1449     * assert_eq!(Some(3), values.next());
1450     * assert_eq!(Some(3), values.next());
1451     * assert_eq!(None, values.next());
1452     * ```
1453     */
1454    #[track_caller]
1455    pub fn insert_row_with<I>(&mut self, row: Row, mut values: I)
1456    where
1457        I: Iterator<Item = T>,
1458    {
1459        assert!(
1460            row <= self.rows(),
1461            "Row to insert must be <= to {}",
1462            self.rows()
1463        );
1464        for column in 0..self.columns() {
1465            self.data.insert(
1466                self.get_index(row, column),
1467                values.next().unwrap_or_else(|| {
1468                    panic!("At least {} values must be provided", self.columns())
1469                }),
1470            );
1471        }
1472        self.rows += 1;
1473    }
1474
1475    /**
1476     * Inserts a new column into the Matrix at the provided index, shifting other
1477     * columns to the right and filling all entries with the provided value.
1478     * Columns are 0 indexed.
1479     *
1480     * # Panics
1481     *
1482     * This will panic if the column is greater than the number of columns in the matrix.
1483     */
1484    #[track_caller]
1485    pub fn insert_column(&mut self, column: Column, value: T) {
1486        assert!(
1487            column <= self.columns(),
1488            "Column to insert must be <= to {}",
1489            self.columns()
1490        );
1491        for row in (0..self.rows()).rev() {
1492            self.data.insert(self.get_index(row, column), value.clone());
1493        }
1494        self.columns += 1;
1495    }
1496
1497    /**
1498     * Inserts a new column into the Matrix at the provided index, shifting other columns
1499     * to the right and filling all entries with the values from the iterator in sequence.
1500     * Columns are 0 indexed.
1501     *
1502     * # Panics
1503     *
1504     * This will panic if the column is greater than the number of columns in the matrix,
1505     * or if the iterator has fewer elements than `self.rows()`.
1506     *
1507     * Example of duplicating a column:
1508     * ```
1509     * use easy_ml::matrices::Matrix;
1510     * let x: Matrix<u8> = Matrix::column(vec![ 1, 2, 3 ]);
1511     * let mut y = x.clone();
1512     * // duplicate the first column as the second column
1513     * y.insert_column_with(1, x.column_iter(0));
1514     * assert_eq!((3, 2), y.size());
1515     * let mut values = y.column_major_iter();
1516     * assert_eq!(Some(1), values.next());
1517     * assert_eq!(Some(2), values.next());
1518     * assert_eq!(Some(3), values.next());
1519     * assert_eq!(Some(1), values.next());
1520     * assert_eq!(Some(2), values.next());
1521     * assert_eq!(Some(3), values.next());
1522     * assert_eq!(None, values.next());
1523     * ```
1524     */
1525    #[track_caller]
1526    pub fn insert_column_with<I>(&mut self, column: Column, values: I)
1527    where
1528        I: Iterator<Item = T>,
1529    {
1530        assert!(
1531            column <= self.columns(),
1532            "Column to insert must be <= to {}",
1533            self.columns()
1534        );
1535        let mut array_values = values.collect::<Vec<T>>();
1536        assert!(
1537            array_values.len() >= self.rows(),
1538            "At least {} values must be provided",
1539            self.rows()
1540        );
1541        for row in (0..self.rows()).rev() {
1542            self.data
1543                .insert(self.get_index(row, column), array_values.pop().unwrap());
1544        }
1545        self.columns += 1;
1546    }
1547
1548    /**
1549     * Makes a copy of this matrix shrunk down in size according to the slice. See
1550     * [retain_mut](Matrix::retain_mut()).
1551     */
1552    pub fn retain(&self, slice: Slice2D) -> Matrix<T> {
1553        let mut retained = self.clone();
1554        retained.retain_mut(slice);
1555        retained
1556    }
1557}
1558
1559/**
1560 * Any matrix of a Cloneable type implements Clone.
1561 */
1562impl<T: Clone> Clone for Matrix<T> {
1563    fn clone(&self) -> Self {
1564        self.map(|element| element)
1565    }
1566}
1567
1568/**
1569 * Any matrix of a Displayable type implements Display
1570 *
1571 * You can control the precision of the formatting using format arguments, i.e.
1572 * `format!("{:.3}", matrix)`
1573 */
1574impl<T: std::fmt::Display> std::fmt::Display for Matrix<T> {
1575    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1576        crate::matrices::views::format_view(self, f)
1577    }
1578}
1579
1580/**
1581 * Any matrix and two different dimension names can be converted to a 2 dimensional tensor with
1582 * the same number of rows and columns.
1583 *
1584 * Conversion will fail if the dimension names for `self.rows()` and `self.columns()` respectively
1585 * are the same.
1586 */
1587impl<T> TryFrom<(Matrix<T>, [crate::tensors::Dimension; 2])> for crate::tensors::Tensor<T, 2> {
1588    type Error = crate::tensors::InvalidShapeError<2>;
1589
1590    fn try_from(value: (Matrix<T>, [crate::tensors::Dimension; 2])) -> Result<Self, Self::Error> {
1591        let (matrix, [row_name, column_name]) = value;
1592        let shape = [(row_name, matrix.rows), (column_name, matrix.columns)];
1593        let check = crate::tensors::InvalidShapeError::new(shape);
1594        if !check.is_valid() {
1595            return Err(check);
1596        }
1597        // Now we know the shape is valid, we can call the standard Tensor constructor knowing
1598        // it won't fail since our data length will match the size of our shape.
1599        Ok(crate::tensors::Tensor::from(shape, matrix.data))
1600    }
1601}
1602
1603/**
1604 * Methods for matrices with numerical types, such as f32 or f64.
1605 *
1606 * Note that unsigned integers are not Numeric because they do not
1607 * implement [Neg](std::ops::Neg). You must first
1608 * wrap unsigned integers via [Wrapping](std::num::Wrapping) or [Saturating](std::num::Saturating).
1609 *
1610 * While these methods will all be defined on signed integer types as well, such as i16 or i32,
1611 * in many cases integers cannot be used sensibly in these computations. If you
1612 * have a matrix of type i8 for example, you should consider mapping it into a floating
1613 * type before doing heavy linear algebra maths on it.
1614 *
1615 * Determinants can be computed without loss of precision using sufficiently large signed
1616 * integers because the only operations performed on the elements are addition, subtraction
1617 * and mulitplication. However the inverse of a matrix such as
1618 *
1619 * ```ignore
1620 * [
1621 *   4, 7
1622 *   2, 8
1623 * ]
1624 * ```
1625 *
1626 * is
1627 *
1628 * ```ignore
1629 * [
1630 *   0.6, -0.7,
1631 *  -0.2, 0.4
1632 * ]
1633 * ```
1634 *
1635 * which requires a type that supports decimals to accurately represent.
1636 *
1637 * Mapping matrix type example:
1638 * ```
1639 * use easy_ml::matrices::Matrix;
1640 * use std::num::Wrapping;
1641 *
1642 * let matrix: Matrix<u8> = Matrix::from(vec![
1643 *     vec![ 2, 3 ],
1644 *     vec![ 6, 0 ]
1645 * ]);
1646 * // determinant is not defined on this matrix because u8 is not Numeric
1647 * // println!("{:?}", matrix.determinant()); // won't compile
1648 * // however Wrapping<u8> is numeric
1649 * let matrix = matrix.map(|element| Wrapping(element));
1650 * println!("{:?}", matrix.determinant()); // -> 238 (overflow)
1651 * println!("{:?}", matrix.map(|element| element.0 as i16).determinant()); // -> -18
1652 * println!("{:?}", matrix.map(|element| element.0 as f32).determinant()); // -> -18.0
1653 * ```
1654 */
1655impl<T: Numeric> Matrix<T>
1656where
1657    for<'a> &'a T: NumericRef<T>,
1658{
1659    /**
1660     * Returns the determinant of this square matrix, or None if the matrix
1661     * does not have a determinant. See [`linear_algebra`](super::linear_algebra::determinant())
1662     */
1663    pub fn determinant(&self) -> Option<T> {
1664        linear_algebra::determinant::<T>(self)
1665    }
1666
1667    /**
1668     * Computes the inverse of a matrix provided that it exists. To have an inverse a
1669     * matrix must be square (same number of rows and columns) and it must also have a
1670     * non zero determinant. See [`linear_algebra`](super::linear_algebra::inverse())
1671     */
1672    pub fn inverse(&self) -> Option<Matrix<T>> {
1673        linear_algebra::inverse::<T>(self)
1674    }
1675
1676    /**
1677     * Computes the covariance matrix for this NxM feature matrix, in which
1678     * each N'th row has M features to find the covariance and variance of. See
1679     * [`linear_algebra`](super::linear_algebra::covariance_column_features())
1680     */
1681    pub fn covariance_column_features(&self) -> Matrix<T> {
1682        linear_algebra::covariance_column_features::<T>(self)
1683    }
1684
1685    /**
1686     * Computes the covariance matrix for this NxM feature matrix, in which
1687     * each M'th column has N features to find the covariance and variance of. See
1688     * [`linear_algebra`](super::linear_algebra::covariance_row_features())
1689     */
1690    pub fn covariance_row_features(&self) -> Matrix<T> {
1691        linear_algebra::covariance_row_features::<T>(self)
1692    }
1693}
1694
1695/**
1696 * Methods for matrices with numerical real valued types, such as f32 or f64.
1697 *
1698 * This excludes signed and unsigned integers as they do not support decimal
1699 * precision and hence can't be used for operations like square roots.
1700 *
1701 * Third party fixed precision and infinite precision decimal types should
1702 * be able to implement all of the methods for [Real] and then utilise these functions.
1703 */
1704impl<T: Real> Matrix<T>
1705where
1706    for<'a> &'a T: RealRef<T>,
1707{
1708    /**
1709     * Computes the [L2 norm](https://en.wikipedia.org/wiki/Euclidean_vector#Length)
1710     * of this row or column vector, also referred to as the length or magnitude,
1711     * and written as ||x||, or sometimes |x|.
1712     *
1713     * ||**a**|| = sqrt(a<sub>1</sub><sup>2</sup> + a<sub>2</sub><sup>2</sup> + a<sub>3</sub><sup>2</sup>...) = sqrt(**a**<sup>T</sup> * **a**)
1714     *
1715     * This is a shorthand for `(x.transpose() * x).scalar().sqrt()` for
1716     * column vectors and `(x * x.transpose()).scalar().sqrt()` for row vectors, ie
1717     * the square root of the dot product of a vector with itself.
1718     *
1719     * The euclidean length can be used to compute a
1720     * [unit vector](https://en.wikipedia.org/wiki/Unit_vector), that is, a
1721     * vector with length of 1. This should not be confused with a unit matrix,
1722     * which is another name for an identity matrix.
1723     *
1724     * ```
1725     * use easy_ml::matrices::Matrix;
1726     * let a = Matrix::column(vec![ 1.0, 2.0, 3.0 ]);
1727     * let length = a.euclidean_length(); // (1^2 + 2^2 + 3^2)^0.5
1728     * let unit = a / length;
1729     * assert_eq!(unit.euclidean_length(), 1.0);
1730     * ```
1731     *
1732     * # Panics
1733     *
1734     * If the matrix is not a vector, ie if it has more than one row and more than one
1735     * column.
1736     */
1737    #[track_caller]
1738    pub fn euclidean_length(&self) -> T {
1739        if self.columns() == 1 {
1740            // column vector
1741            (self.transpose() * self).scalar().sqrt()
1742        } else if self.rows() == 1 {
1743            // row vector
1744            (self * self.transpose()).scalar().sqrt()
1745        } else {
1746            panic!(
1747                "Cannot compute unit vector of a non vector, rows: {}, columns: {}",
1748                self.rows(),
1749                self.columns()
1750            );
1751        }
1752    }
1753}
1754
1755// FIXME: want this to be callable in the main numeric impl block
1756impl<T: Numeric> Matrix<T> {
1757    /**
1758     * Creates a diagonal matrix of the provided size with the diagonal elements
1759     * set to the provided value and all other elements in the matrix set to 0.
1760     * A diagonal matrix is always square.
1761     *
1762     * The size is still taken as a tuple to facilitate creating a diagonal matrix
1763     * from the dimensionality of an existing one. If the provided value is 1 then
1764     * this will create an identity matrix.
1765     *
1766     * A 3 x 3 identity matrix:
1767     * ```ignore
1768     * [
1769     *   1, 0, 0
1770     *   0, 1, 0
1771     *   0, 0, 1
1772     * ]
1773     * ```
1774     *
1775     * # Panics
1776     *
1777     * If the provided size is not square.
1778     */
1779    #[track_caller]
1780    pub fn diagonal(value: T, size: (Row, Column)) -> Matrix<T> {
1781        assert!(size.0 == size.1);
1782        let mut matrix = Matrix::empty(T::zero(), size);
1783        for i in 0..size.0 {
1784            matrix.set(i, i, value.clone());
1785        }
1786        matrix
1787    }
1788
1789    /**
1790     * Creates a diagonal matrix with the elements along the diagonal set to the
1791     * provided values and all other elements in the matrix set to 0.
1792     * A diagonal matrix is always square.
1793     *
1794     * Examples
1795     *
1796     * ```
1797     * use easy_ml::matrices::Matrix;
1798     * let matrix = Matrix::from_diagonal(vec![ 1, 1, 1 ]);
1799     * assert_eq!(matrix.size(), (3, 3));
1800     * let copy = Matrix::from_diagonal(matrix.diagonal_iter().collect());
1801     * assert_eq!(matrix, copy);
1802     * assert_eq!(matrix, Matrix::from(vec![
1803     *     vec![ 1, 0, 0 ],
1804     *     vec![ 0, 1, 0 ],
1805     *     vec![ 0, 0, 1 ],
1806     * ]))
1807     * ```
1808     */
1809    pub fn from_diagonal(values: Vec<T>) -> Matrix<T> {
1810        let mut matrix = Matrix::empty(T::zero(), (values.len(), values.len()));
1811        for (i, element) in values.into_iter().enumerate() {
1812            matrix.set(i, i, element);
1813        }
1814        matrix
1815    }
1816}
1817
1818/**
1819 * PartialEq is implemented as two matrices are equal if and only if all their elements
1820 * are equal and they have the same size.
1821 */
1822impl<T: PartialEq> PartialEq for Matrix<T> {
1823    #[inline]
1824    fn eq(&self, other: &Self) -> bool {
1825        if self.rows() != other.rows() {
1826            return false;
1827        }
1828        if self.columns() != other.columns() {
1829            return false;
1830        }
1831        // perform elementwise check, return true only if every element in
1832        // each matrix is the same
1833        self.data.iter().zip(other.data.iter()).all(|(x, y)| x == y)
1834    }
1835}
1836
#[test]
fn test_sync() {
    // This only compiles if a Matrix<f64> can be shared between threads.
    fn requires_sync<S: Sync>() {}
    requires_sync::<Matrix<f64>>();
}
1842
#[test]
fn test_send() {
    // This only compiles if a Matrix<f64> can be moved to another thread.
    fn requires_send<S: Send>() {}
    requires_send::<Matrix<f64>>();
}
1848
#[cfg(feature = "serde")]
mod serde_impls {
    use crate::matrices::{Column, Matrix, Row};
    use serde::de::Error;
    use serde::{Deserialize, Deserializer};

    /// Raw mirror of the serialized `Matrix` fields. These are deserialized
    /// without any validation, so they must be checked before building a Matrix.
    #[derive(Deserialize)]
    #[serde(rename = "Matrix")]
    struct MatrixDeserialize<T> {
        data: Vec<T>,
        rows: Row,
        columns: Column,
    }

    impl<'de, T> Deserialize<'de> for Matrix<T>
    where
        T: Deserialize<'de>,
    {
        /// Deserializes a Matrix from its `data`, `rows` and `columns` fields.
        ///
        /// # Errors
        ///
        /// Returns a deserialization error instead of panicking if either
        /// dimension is 0, or if the length of `data` does not equal
        /// `rows * columns` (including when that product overflows).
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let d = MatrixDeserialize::<T>::deserialize(deserializer)?;
            // Serialized input is untrusted. Check the size invariants here so
            // that malformed data becomes an error the caller can handle,
            // rather than a panic inside the constructor.
            if d.rows == 0 || d.columns == 0 {
                return Err(D::Error::custom(format!(
                    "Matrix must be at least 1x1, rows: {}, columns: {}",
                    d.rows, d.columns
                )));
            }
            if d.rows.checked_mul(d.columns) != Some(d.data.len()) {
                return Err(D::Error::custom(format!(
                    "Matrix data length {} does not match rows: {}, columns: {}",
                    d.data.len(),
                    d.rows,
                    d.columns
                )));
            }
            // Safety: Use the no copy constructor that performs validation to prevent invalid
            // serialized data being created as a Matrix, which would then break all the
            // code that's relying on these invariants.
            Ok(Matrix::from_flat_row_major((d.rows, d.columns), d.data))
        }
    }
}

#[cfg(feature = "serde")]
#[test]
fn test_serialize() {
    fn assert_serialize<T: Serialize>() {}
    assert_serialize::<Matrix<f64>>();
}

#[cfg(feature = "serde")]
#[test]
fn test_deserialize() {
    use serde::Deserialize;
    fn assert_deserialize<'de, T: Deserialize<'de>>() {}
    assert_deserialize::<Matrix<f64>>();
}

#[cfg(feature = "serde")]
#[test]
fn test_serialization_deserialization_loop() {
    #[rustfmt::skip]
    let matrix = Matrix::from(vec![
        vec![1,  2,  3,  4],
        vec![5,  6,  7,  8],
        vec![9, 10, 11, 12],
    ]);
    let encoded = toml::to_string(&matrix).unwrap();
    assert_eq!(
        encoded,
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 4
"#,
    );
    let parsed: Result<Matrix<i32>, _> = toml::from_str(&encoded);
    assert!(parsed.is_ok());
    assert_eq!(matrix, parsed.unwrap())
}

#[cfg(feature = "serde")]
#[test]
fn test_deserialization_validation() {
    // Data length does not match the declared dimensions.
    let mismatched: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
rows = 3
columns = 3
"#,
    );
    assert!(mismatched.is_err());
    // Matrices must be at least 1x1.
    let empty: Result<Matrix<i32>, _> = toml::from_str(
        r#"data = []
rows = 0
columns = 0
"#,
    );
    assert!(empty.is_err());
}
1928
#[test]
fn test_indexing() {
    // 2x2: row major flat indexes map to (row, column) pairs and back.
    let square = Matrix::from(vec![vec![1, 2], vec![3, 4]]);
    assert_eq!(square.get_index(0, 1), 1);
    assert_eq!(square.get_row_column(1), (0, 1));
    assert_eq!(square.get(0, 1), 2);

    // 2x3: a non square matrix exercises the column count in the arithmetic.
    let wide = Matrix::from(vec![vec![1, 2, 3], vec![5, 6, 7]]);
    assert_eq!(wide.get_index(1, 2), 5);
    assert_eq!(wide.get_row_column(5), (1, 2));
    assert_eq!(wide.get(1, 2), 7);

    // map_with_index must pass each element's own (row, column) position.
    let labelled = Matrix::from(vec![vec![0, 0], vec![0, 0], vec![0, 0]])
        .map_with_index(|_, r, c| format!("{:?}x{:?}", r, c));
    let expected = Matrix::from(vec![
        vec!["0x0", "0x1"],
        vec!["1x0", "1x1"],
        vec!["2x0", "2x1"],
    ])
    .map(|x| x.to_owned());
    assert_eq!(labelled, expected);
}