linalg_traits/matrix/
mat.rs1use crate::Vector;
2use crate::matrix::matrix_trait::Matrix;
3use crate::scalar::Scalar;
4use std::borrow::Cow;
5use std::iter::Iterator;
6use std::ops::{Index, IndexMut};
7
8#[derive(Clone, Debug, PartialEq)]
21pub struct Mat<S>
22where
23 S: Scalar,
24{
25 data: Vec<S>,
26 rows: usize,
27 cols: usize,
28}
29
30impl<S> Mat<S>
31where
32 S: Scalar,
33{
34 fn index(&self, row: usize, col: usize) -> usize {
36 assert!(row < self.rows && col < self.cols, "Index out of bounds");
37 row * self.cols + col
38 }
39
40 pub fn iter(&self) -> impl Iterator<Item = &S> {
46 self.data.iter()
47 }
48
49 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut S> {
55 self.data.iter_mut()
56 }
57}
58
59impl<S> IntoIterator for Mat<S>
60where
61 S: Scalar,
62{
63 type Item = S;
64 type IntoIter = std::vec::IntoIter<S>;
65
66 fn into_iter(self) -> Self::IntoIter {
67 self.data.into_iter()
68 }
69}
70
71impl<S: Scalar> Index<(usize, usize)> for Mat<S> {
72 type Output = S;
73 fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
74 &self.data[self.index(row, col)]
75 }
76}
77
78impl<S: Scalar> IndexMut<(usize, usize)> for Mat<S> {
79 fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
80 let idx = self.index(row, col);
81 &mut self.data[idx]
82 }
83}
84
85impl<S> Matrix<S> for Mat<S>
86where
87 S: Scalar,
88{
89 type VectorM = Vec<S>;
90
91 type VectorN = Vec<S>;
92
93 fn is_statically_sized() -> bool {
94 false
95 }
96
97 fn is_dynamically_sized() -> bool {
98 true
99 }
100
101 fn is_row_major() -> bool {
102 true
103 }
104
105 fn is_column_major() -> bool {
106 false
107 }
108
109 fn new_with_shape(rows: usize, cols: usize) -> Self {
110 Mat {
111 data: vec![S::zero(); rows * cols],
112 rows,
113 cols,
114 }
115 }
116
117 fn shape(&self) -> (usize, usize) {
118 (self.rows, self.cols)
119 }
120
121 fn from_row_slice(rows: usize, cols: usize, slice: &[S]) -> Self {
122 assert_eq!(
123 slice.len(),
124 rows * cols,
125 "Slice length ({}) not compatible with matrix dimensions ({}x{}).",
126 slice.len(),
127 rows,
128 cols,
129 );
130 Mat {
131 data: slice.to_vec(),
132 rows,
133 cols,
134 }
135 }
136
137 fn from_col_slice(rows: usize, cols: usize, slice: &[S]) -> Self {
138 assert_eq!(
139 slice.len(),
140 rows * cols,
141 "Slice length ({}) not compatible with matrix dimensions ({}x{}).",
142 slice.len(),
143 rows,
144 cols,
145 );
146 let mut data = Vec::with_capacity(rows * cols);
147 for row in 0..rows {
148 for col in 0..cols {
149 data.push(slice[row + col * rows]);
150 }
151 }
152 Mat { data, rows, cols }
153 }
154
155 fn as_slice<'a>(&'a self) -> Cow<'a, [S]> {
156 Cow::from(self.data.as_slice())
157 }
158
159 fn add(&self, other: &Self) -> Self {
160 self.assert_same_shape(other);
161 Mat {
162 data: self.data.add(&other.data),
163 rows: self.rows,
164 cols: self.cols,
165 }
166 }
167
168 fn add_assign(&mut self, other: &Self) {
169 self.assert_same_shape(other);
170 self.data.add_assign(&other.data);
171 }
172
173 fn sub(&self, other: &Self) -> Self {
174 self.assert_same_shape(other);
175 Mat {
176 data: self.data.sub(&other.data),
177 rows: self.rows,
178 cols: self.cols,
179 }
180 }
181
182 fn sub_assign(&mut self, other: &Self) {
183 self.assert_same_shape(other);
184 self.data.sub_assign(&other.data);
185 }
186
187 fn mul(&self, scalar: S) -> Self {
188 Mat {
189 data: self.data.mul(scalar),
190 rows: self.rows,
191 cols: self.cols,
192 }
193 }
194
195 fn mul_assign(&mut self, scalar: S) {
196 self.data.mul_assign(scalar);
197 }
198
199 fn div(&self, scalar: S) -> Self {
200 Mat {
201 data: self.data.div(scalar),
202 rows: self.rows,
203 cols: self.cols,
204 }
205 }
206
207 fn div_assign(&mut self, scalar: S) {
208 self.data.div_assign(scalar);
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_indexing() {
        let mut mat = Mat::<f64>::new_with_shape(2, 2);
        mat[(0, 0)] = 1.0;
        mat[(0, 1)] = 2.0;
        mat[(1, 0)] = 3.0;
        mat[(1, 1)] = 4.0;
        assert_eq!(mat[(0, 0)], 1.0);
        assert_eq!(mat[(0, 1)], 2.0);
        assert_eq!(mat[(1, 0)], 3.0);
        assert_eq!(mat[(1, 1)], 4.0);
    }

    #[test]
    fn test_from_col_slice_matches_from_row_slice() {
        // The 2x3 matrix [[1, 2, 3], [4, 5, 6]] given in both layouts.
        let by_rows = Mat::<f64>::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let by_cols = Mat::<f64>::from_col_slice(2, 3, &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(by_rows, by_cols);
        assert_eq!(by_cols.shape(), (2, 3));
        assert_eq!(by_cols[(0, 2)], 3.0);
        assert_eq!(by_cols[(1, 0)], 4.0);
        assert_eq!(&*by_cols.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0][..]);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_row_out_of_bounds_panics() {
        let mat = Mat::<f64>::new_with_shape(2, 2);
        let _value = mat[(2, 0)];
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_col_out_of_bounds_panics() {
        // (0, 2) flattens to offset 2, which lies inside the 4-element buffer;
        // the per-axis check must still reject it.
        let mat = Mat::<f64>::new_with_shape(2, 2);
        let _value = mat[(0, 2)];
    }

    #[test]
    #[should_panic(expected = "Slice length")]
    fn test_from_row_slice_rejects_wrong_length() {
        let _mat = Mat::<f64>::from_row_slice(2, 2, &[1.0, 2.0, 3.0]);
    }
}