// linalg_traits/matrix/mat.rs

1use crate::Vector;
2use crate::matrix::matrix_trait::Matrix;
3use crate::scalar::Scalar;
4use std::borrow::Cow;
5use std::iter::Iterator;
6use std::ops::{Index, IndexMut};
7
8/// Extremely basic matrix type, written as `Mat<S>`, short for "matrix".
9///
10/// # Implementation Details
11///
12/// * The underlying data structure is a [`Vec<S>`].
13/// * This matrix implementation is row-major; the elements of the matrix are stored row-by-row
14///   in a one-dimensional "flat" data structure (in this case a [`Vec<S>`]).
15///
16/// # Motivation
17///
18/// Rust does not have a matrix type in the `std` library, and users of this crate may not want to
19/// have dependencies such as [`nalgebra`], [`ndarray`], and/or [`faer`].
20#[derive(Clone, Debug, PartialEq)]
21pub struct Mat<S>
22where
23    S: Scalar,
24{
25    data: Vec<S>,
26    rows: usize,
27    cols: usize,
28}
29
30impl<S> Mat<S>
31where
32    S: Scalar,
33{
34    /// Helper function to calculate the linear index from row and column indices.
35    fn index(&self, row: usize, col: usize) -> usize {
36        assert!(row < self.rows && col < self.cols, "Index out of bounds");
37        row * self.cols + col
38    }
39
40    /// Returns an iterator over the elements of the matrix.
41    ///
42    /// # Returns
43    ///
44    /// An iterator that yields references to the elements of the matrix.
45    pub fn iter(&self) -> impl Iterator<Item = &S> {
46        self.data.iter()
47    }
48
49    /// Returns a mutable iterator over the elements of the matrix.
50    ///
51    /// # Returns
52    ///
53    /// An iterator that yields mutable references to the elements of the matrix.
54    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut S> {
55        self.data.iter_mut()
56    }
57}
58
59impl<S> IntoIterator for Mat<S>
60where
61    S: Scalar,
62{
63    type Item = S;
64    type IntoIter = std::vec::IntoIter<S>;
65
66    fn into_iter(self) -> Self::IntoIter {
67        self.data.into_iter()
68    }
69}
70
71impl<S: Scalar> Index<(usize, usize)> for Mat<S> {
72    type Output = S;
73    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
74        &self.data[self.index(row, col)]
75    }
76}
77
78impl<S: Scalar> IndexMut<(usize, usize)> for Mat<S> {
79    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
80        let idx = self.index(row, col);
81        &mut self.data[idx]
82    }
83}
84
85impl<S> Matrix<S> for Mat<S>
86where
87    S: Scalar,
88{
89    type VectorM = Vec<S>;
90
91    type VectorN = Vec<S>;
92
93    fn is_statically_sized() -> bool {
94        false
95    }
96
97    fn is_dynamically_sized() -> bool {
98        true
99    }
100
101    fn is_row_major() -> bool {
102        true
103    }
104
105    fn is_column_major() -> bool {
106        false
107    }
108
109    fn new_with_shape(rows: usize, cols: usize) -> Self {
110        Mat {
111            data: vec![S::zero(); rows * cols],
112            rows,
113            cols,
114        }
115    }
116
117    fn shape(&self) -> (usize, usize) {
118        (self.rows, self.cols)
119    }
120
121    fn from_row_slice(rows: usize, cols: usize, slice: &[S]) -> Self {
122        assert_eq!(
123            slice.len(),
124            rows * cols,
125            "Slice length ({}) not compatible with matrix dimensions ({}x{}).",
126            slice.len(),
127            rows,
128            cols,
129        );
130        Mat {
131            data: slice.to_vec(),
132            rows,
133            cols,
134        }
135    }
136
137    fn from_col_slice(rows: usize, cols: usize, slice: &[S]) -> Self {
138        assert_eq!(
139            slice.len(),
140            rows * cols,
141            "Slice length ({}) not compatible with matrix dimensions ({}x{}).",
142            slice.len(),
143            rows,
144            cols,
145        );
146        let mut data = Vec::with_capacity(rows * cols);
147        for row in 0..rows {
148            for col in 0..cols {
149                data.push(slice[row + col * rows]);
150            }
151        }
152        Mat { data, rows, cols }
153    }
154
155    fn as_slice<'a>(&'a self) -> Cow<'a, [S]> {
156        Cow::from(self.data.as_slice())
157    }
158
159    fn add(&self, other: &Self) -> Self {
160        self.assert_same_shape(other);
161        Mat {
162            data: self.data.add(&other.data),
163            rows: self.rows,
164            cols: self.cols,
165        }
166    }
167
168    fn add_assign(&mut self, other: &Self) {
169        self.assert_same_shape(other);
170        self.data.add_assign(&other.data);
171    }
172
173    fn sub(&self, other: &Self) -> Self {
174        self.assert_same_shape(other);
175        Mat {
176            data: self.data.sub(&other.data),
177            rows: self.rows,
178            cols: self.cols,
179        }
180    }
181
182    fn sub_assign(&mut self, other: &Self) {
183        self.assert_same_shape(other);
184        self.data.sub_assign(&other.data);
185    }
186
187    fn mul(&self, scalar: S) -> Self {
188        Mat {
189            data: self.data.mul(scalar),
190            rows: self.rows,
191            cols: self.cols,
192        }
193    }
194
195    fn mul_assign(&mut self, scalar: S) {
196        self.data.mul_assign(scalar);
197    }
198
199    fn div(&self, scalar: S) -> Self {
200        Mat {
201            data: self.data.div(scalar),
202            rows: self.rows,
203            cols: self.cols,
204        }
205    }
206
207    fn div_assign(&mut self, scalar: S) {
208        self.data.div_assign(scalar);
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes each element of a 2x2 matrix through `IndexMut`, then reads it back via `Index`.
    #[test]
    fn test_indexing() {
        let cases = [((0, 0), 1.0), ((0, 1), 2.0), ((1, 0), 3.0), ((1, 1), 4.0)];

        let mut mat = Mat::<f64>::new_with_shape(2, 2);
        for &(position, value) in cases.iter() {
            mat[position] = value;
        }
        for &(position, value) in cases.iter() {
            assert_eq!(mat[position], value);
        }
    }
}
228}