1use std::ops::{Add, Neg, Sub, Mul, Index, IndexMut, AddAssign, SubAssign, MulAssign, Range};
2use std::fmt::Debug;
3use nalgebra::{ClosedSubAssign, ClosedMulAssign};
4use nalgebra_sparse::na::{Scalar, ClosedAddAssign, DMatrix};
5use delegate::delegate;
6use derive_more::Display;
7use auto_impl_ops::auto_ops;
8use num_traits::{Zero, One};
9use crate::MatTrait;
10use crate::sparse::SpMat;
11
12#[derive(Clone, Debug, Display, PartialEq, Eq)]
13pub struct Mat<R> {
14 inner: DMatrix<R>
15}
16
17impl<R> MatTrait for Mat<R> {
18 fn shape(&self) -> (usize, usize) {
19 (self.inner.nrows(), self.inner.ncols())
20 }
21}
22
23impl<R> Mat<R> {
24 pub fn inner(&self) -> &DMatrix<R> {
25 &self.inner
26 }
27
28 pub fn inner_mut(&mut self) -> &mut DMatrix<R> {
29 &mut self.inner
30 }
31
32 pub fn into_inner(self) -> DMatrix<R> {
33 self.inner
34 }
35
36 pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &R)> {
37 let m = self.nrows();
38 self.inner.iter().enumerate().map(move |(i, a)|
39 (i % m, i / m, a)
40 )
41 }
42}
43
44impl<R> Mat<R>
45where R: Scalar {
46 pub fn from_data<I>(shape: (usize, usize), data: I) -> Self
47 where I: IntoIterator<Item = R> {
48 DMatrix::from_row_iterator(shape.0, shape.1, data).into()
49 }
50
51 pub fn zero(shape: (usize, usize)) -> Self
52 where R: Zero {
53 let inner = DMatrix::zeros(shape.0, shape.1);
54 Self::from(inner)
55 }
56
57 pub fn is_zero(&self) -> bool
58 where R: Zero {
59 self.iter().all(|e| e.2.is_zero())
60 }
61
62 pub fn id(size: usize) -> Self
63 where R: Zero + One {
64 let inner = DMatrix::identity(size, size);
65 Self::from(inner)
66 }
67
68 pub fn is_id(&self) -> bool
69 where R: Zero + One {
70 self.is_square() && self.iter().all(|(i, j, a)|
71 i == j && a.is_one() ||
72 i != j && a.is_zero()
73 )
74 }
75
76 pub fn diag<I>(shape: (usize, usize), entries: I) -> Self
77 where R: Zero, I: IntoIterator<Item = R> {
78 let mut mat = Self::zero(shape);
79 for (i, a) in entries.into_iter().enumerate() {
80 mat[(i, i)] = a;
81 }
82 mat
83 }
84
85 pub fn is_diag(&self) -> bool
86 where R: Zero {
87 self.iter().all(|(i, j, a)|
88 i == j || a.is_zero()
89 )
90 }
91
92 pub fn submat(&self, rows: Range<usize>, cols: Range<usize>) -> Mat<R> {
93 let (i0, i1) = (rows.start, rows.end);
94 let (j0, j1) = (cols.start, cols.end);
95
96 assert!(i0 <= i1 && i1 <= self.nrows());
97 assert!(j0 <= j1 && j1 <= self.ncols());
98
99 let slice = self.inner.view((i0, j0), (i1 - i0, j1 - j0));
100 Self::from(slice.clone_owned())
101 }
102
103 pub fn submat_rows(&self, rows: Range<usize>) -> Mat<R> {
104 let n = self.ncols();
105 self.submat(rows, 0 .. n)
106 }
107
108 pub fn submat_cols(&self, cols: Range<usize>) -> Mat<R> {
109 let m = self.nrows();
110 self.submat(0 .. m, cols)
111 }
112
113 pub fn into_sparse(self) -> SpMat<R>
114 where R: Zero + ClosedAddAssign {
115 self.into()
116 }
117}
118
119impl<R> From<DMatrix<R>> for Mat<R> {
120 fn from(inner: DMatrix<R>) -> Self {
121 Self { inner }
122 }
123}
124
125impl<R> From<SpMat<R>> for Mat<R>
126where R: Scalar + Zero + ClosedAddAssign {
127 fn from(value: SpMat<R>) -> Self {
128 let inner = DMatrix::from(value.inner());
129 Self::from(inner)
130 }
131}
132
133impl<R> Index<(usize, usize)> for Mat<R> {
134 type Output = R;
135 delegate! {
136 to self.inner {
137 fn index(&self, index: (usize, usize)) -> &R;
138 }
139 }
140}
141
142impl<R> IndexMut<(usize, usize)> for Mat<R> {
143 delegate! {
144 to self.inner {
145 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output;
146 }
147 }
148}
149
150impl<R> Default for Mat<R>
151where R: Scalar + Zero {
152 fn default() -> Self {
153 Self::zero((0, 0))
154 }
155}
156
157impl<R> Neg for Mat<R>
158where R: Scalar + Neg<Output = R> {
159 type Output = Self;
160 fn neg(self) -> Self::Output {
161 Mat::from(-self.inner)
162 }
163}
164
165impl<R> Neg for &Mat<R>
166where R: Scalar + Neg<Output = R> {
167 type Output = Mat<R>;
168 fn neg(self) -> Self::Output {
169 Mat::from(-&self.inner)
170 }
171}
172
173#[auto_ops]
174impl<R> AddAssign<&Mat<R>> for Mat<R>
175where R: Scalar + ClosedAddAssign {
176 fn add_assign(&mut self, rhs: &Self) {
177 self.inner += &rhs.inner;
178 }
179}
180
181#[auto_ops]
182impl<R> SubAssign<&Mat<R>> for Mat<R>
183where R: Scalar + ClosedSubAssign {
184 fn sub_assign(&mut self, rhs: &Self) {
185 self.inner -= &rhs.inner
186 }
187}
188
189#[auto_ops]
190impl<'a, 'b, R> Mul<&'b Mat<R>> for &'a Mat<R>
191where R: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign {
192 type Output = Mat<R>;
193 fn mul(self, rhs: &'b Mat<R>) -> Self::Output {
194 let prod = &self.inner * &rhs.inner;
195 Mat::from(prod)
196 }
197}
198
199impl<R> Mat<R>
200where R: Scalar {
201 pub fn swap_rows(&mut self, i: usize, j: usize) {
202 self.inner.swap_rows(i, j);
203 }
204
205 pub fn swap_cols(&mut self, i: usize, j: usize) {
206 self.inner.swap_columns(i, j);
207 }
208
209 pub fn mul_row(&mut self, i: usize, r: &R)
210 where R: ClosedMulAssign {
211 self.inner.row_mut(i).mul_assign(r.clone())
212 }
213
214 pub fn mul_col(&mut self, j: usize, r: &R)
215 where R: ClosedMulAssign {
216 self.inner.column_mut(j).mul_assign(r.clone())
217 }
218
219 pub fn add_row_to(&mut self, i: usize, j: usize, r: &R)
220 where R: ClosedAddAssign + ClosedMulAssign {
221 let row = self.inner.row(i).mul(r.clone());
222 self.inner.row_mut(j).add_assign(row)
223 }
224
225 pub fn add_col_to(&mut self, i: usize, j: usize, r: &R)
226 where R: ClosedAddAssign + ClosedMulAssign {
227 let col = self.inner.column(i).mul(r.clone());
228 self.inner.column_mut(j).add_assign(col)
229 }
230
231 pub fn left_elementary(&mut self, comps: [&R; 4], i: usize, j: usize)
233 where R: ClosedAddAssign + ClosedMulAssign {
234 let [a, b, c, d] = comps.map(Clone::clone);
235
236 let r_i = self.inner.row(i);
237 let r_j = self.inner.row(j);
238
239 let s_i = &r_i * a + &r_j * b;
240 let s_j = &r_i * c + &r_j * d;
241
242 self.inner.set_row(i, &s_i);
243 self.inner.set_row(j, &s_j);
244 }
245
246 pub fn right_elementary(&mut self, comps: [&R; 4], i: usize, j: usize)
248 where R: ClosedAddAssign + ClosedMulAssign {
249 let [a, b, c, d] = comps.map(Clone::clone);
250
251 let r_i = self.inner.column(i);
252 let r_j = self.inner.column(j);
253
254 let s_i = &r_i * a + &r_j * b;
255 let s_j = &r_i * c + &r_j * d;
256
257 self.inner.set_column(i, &s_i);
258 self.inner.set_column(j, &s_j);
259 }
260}
261
// Unit tests for the dense `Mat` wrapper. Matrices are written row by row via
// `from_data`, which reads its input in row-major order.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn init() {
        let a = Mat::from_data((2, 3), [1,2,3,4,5,6]);

        assert_eq!(a.nrows(), 2);
        assert_eq!(a.ncols(), 3);
        assert_eq!(a.into_inner(), DMatrix::from_row_slice(2, 3, &[1,2,3,4,5,6]));
    }

    #[test]
    fn eq() {
        let a = Mat::from_data((2, 3), [1,2,3,4,5,6]);
        let b = Mat::from_data((2, 3), [1,2,0,4,5,6]);
        let c = Mat::from_data((3, 2), [1,2,3,4,5,6]);

        assert_eq!(a, a);
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn square() {
        let a: Mat<i32> = Mat::zero((3, 3));
        assert!(a.is_square());

        let a: Mat<i32> = Mat::zero((3, 2));
        assert!(!a.is_square());
    }

    #[test]
    fn zero() {
        let a: Mat<i32> = Mat::zero((3, 2));
        assert!(a.is_zero());

        let a = Mat::from_data((2, 3), [1,2,3,4,5,6]);
        assert!(!a.is_zero());
    }

    #[test]
    fn id() {
        let a: Mat<i32> = Mat::id(3);
        assert!(a.is_id());

        let a = Mat::from_data((2, 2), [1,2,3,4]);
        assert!(!a.is_id());

        let a = Mat::from_data((2, 3), [1,0,0,0,1,0]);
        assert!(!a.is_id());
    }

    #[test]
    fn diag() {
        // Diagonal matrices need not be square.
        let a: Mat<i32> = Mat::diag((2, 3), [1, 2]);
        assert!(a.is_diag());
        assert_eq!(a, Mat::from_data((2, 3), [1,0,0,0,2,0]));

        let b = Mat::from_data((2, 2), [1,2,0,4]);
        assert!(!b.is_diag());
    }

    #[test]
    fn iter_is_column_major() {
        let a = Mat::from_data((2, 2), [1,2,3,4]);
        let entries: Vec<_> = a.iter().map(|(i, j, &x)| (i, j, x)).collect();
        assert_eq!(entries, vec![(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)]);
    }

    #[test]
    fn default_is_empty() {
        let a: Mat<i32> = Mat::default();
        assert_eq!(a.shape(), (0, 0));
    }

    #[test]
    fn swap_rows() {
        let mut a = Mat::from_data((3, 4), 1..=12);
        a.swap_rows(0, 1);
        assert_eq!(a, Mat::from_data((3, 4), [5,6,7,8,1,2,3,4,9,10,11,12]));
    }

    #[test]
    fn swap_cols() {
        let mut a = Mat::from_data((3, 4), 1..=12);
        a.swap_cols(0, 1);
        assert_eq!(a, Mat::from_data((3, 4), [2,1,3,4,6,5,7,8,10,9,11,12]));
    }

    #[test]
    fn mul_row() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.mul_row(1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1,2,3,40,50,60,7,8,9]));
    }

    #[test]
    fn mul_col() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.mul_col(1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1,20,3,4,50,6,7,80,9]));
    }

    #[test]
    fn add_row_to() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.add_row_to(0, 1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1,2,3,14,25,36,7,8,9]));
    }

    #[test]
    fn add_col_to() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.add_col_to(0, 1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1,12,3,4,45,6,7,78,9]));
    }

    #[test]
    fn left_elementary() {
        // Rows 0 and 2 become [1 2; 3 4] * [row_0; row_2].
        let mut a = Mat::from_data((3, 2), [1,2,3,4,5,6]);
        a.left_elementary([&1, &2, &3, &4], 0, 2);
        assert_eq!(a, Mat::from_data((3, 2), [11,14,3,4,23,30]));
    }

    #[test]
    fn add() {
        let a = Mat::from_data((3, 2), [1,2,3,4,5,6]);
        let b = Mat::from_data((3, 2), [8,2,4,0,2,1]);
        let c = a + b;
        assert_eq!(c, Mat::from_data((3, 2), [9,4,7,4,7,7]));
    }

    #[test]
    fn sub() {
        let a = Mat::from_data((3, 2), [1,2,3,4,5,6]);
        let b = Mat::from_data((3, 2), [8,2,4,0,2,1]);
        let c = a - b;
        assert_eq!(c, Mat::from_data((3, 2), [-7,0,-1,4,3,5]));
    }

    #[test]
    fn add_sub_assign() {
        let mut a = Mat::from_data((2, 2), [1,2,3,4]);
        let b = Mat::from_data((2, 2), [4,3,2,1]);
        a += &b;
        assert_eq!(a, Mat::from_data((2, 2), [5,5,5,5]));
        a -= &b;
        assert_eq!(a, Mat::from_data((2, 2), [1,2,3,4]));
    }

    #[test]
    fn neg() {
        let a = Mat::from_data((3, 2), [1,2,3,4,5,6]);
        assert_eq!(-a, Mat::from_data((3, 2), [-1,-2,-3,-4,-5,-6]));
    }

    #[test]
    fn neg_ref() {
        let a = Mat::from_data((2, 2), [1,-2,3,0]);
        assert_eq!(-&a, Mat::from_data((2, 2), [-1,2,-3,0]));
        // The borrowed operand is unchanged.
        assert_eq!(a, Mat::from_data((2, 2), [1,-2,3,0]));
    }

    #[test]
    fn mul() {
        let a = Mat::from_data((2, 3), [1,2,3,4,5,6]);
        let b = Mat::from_data((3, 2), [1,2,1,-1,0,2]);
        let c = a * b;
        assert_eq!(c, Mat::from_data((2, 2), [3,6,9,15]));
    }

    #[test]
    fn to_sparse() {
        let dns = Mat::from_data((2, 3), [1,2,3,4,5,6]);
        let sps = dns.into_sparse();
        assert_eq!(sps, SpMat::from_dense_data((2, 3), [1,2,3,4,5,6]));
    }

    #[test]
    fn from_sparse() {
        let sps = SpMat::from_dense_data((2, 3), [1,2,3,4,5,6]);
        let dns = Mat::from(sps);
        assert_eq!(dns, Mat::from_data((2, 3), [1,2,3,4,5,6]));
    }

    #[test]
    fn submat() {
        let a = Mat::from_data((3, 4), [
            1, 2, 3, 7,
            4, 5, 6, 8,
            9,10,11,12
        ]);
        let b = a.submat(1..3, 2..4);
        assert_eq!(b, Mat::from_data((2, 2), [
            6, 8,
            11,12
        ]));
    }

    #[test]
    fn submat_rows_cols() {
        let a = Mat::from_data((3, 3), 1..=9);
        assert_eq!(a.submat_rows(1..3), Mat::from_data((2, 3), [4,5,6,7,8,9]));
        assert_eq!(a.submat_cols(0..2), Mat::from_data((3, 2), [1,2,4,5,7,8]));
    }
}