// Source file: yui_matrix/dense/mat.rs

1use std::ops::{Add, Neg, Sub, Mul, Index, IndexMut, AddAssign, SubAssign, MulAssign, Range};
2use std::fmt::Debug;
3use nalgebra::{ClosedSubAssign, ClosedMulAssign};
4use nalgebra_sparse::na::{Scalar, ClosedAddAssign, DMatrix};
5use delegate::delegate;
6use derive_more::Display;
7use auto_impl_ops::auto_ops;
8use num_traits::{Zero, One};
9use crate::MatTrait;
10use crate::sparse::SpMat;
11
12#[derive(Clone, Debug, Display, PartialEq, Eq)]
13pub struct Mat<R> {
14    inner: DMatrix<R>
15}
16
17impl<R> MatTrait for Mat<R> {
18    fn shape(&self) -> (usize, usize) {
19        (self.inner.nrows(), self.inner.ncols())
20    }
21}
22
23impl<R> Mat<R> {
24    pub fn inner(&self) -> &DMatrix<R> {
25        &self.inner
26    }
27
28    pub fn inner_mut(&mut self) -> &mut DMatrix<R> {
29        &mut self.inner
30    }
31
32    pub fn into_inner(self) -> DMatrix<R> {
33        self.inner
34    }
35
36    pub fn iter(&self) -> impl Iterator<Item = (usize, usize, &R)> { 
37        let m = self.nrows();
38        self.inner.iter().enumerate().map(move |(i, a)| 
39            (i % m, i / m, a)
40        )
41    }
42}
43
44impl<R> Mat<R>
45where R: Scalar {
46    pub fn from_data<I>(shape: (usize, usize), data: I) -> Self
47    where I: IntoIterator<Item = R> { 
48        DMatrix::from_row_iterator(shape.0, shape.1, data).into()
49    }
50
51    pub fn zero(shape: (usize, usize)) -> Self
52    where R: Zero { 
53        let inner = DMatrix::zeros(shape.0, shape.1);
54        Self::from(inner)
55    }
56
57    pub fn is_zero(&self) -> bool
58    where R: Zero { 
59        self.iter().all(|e| e.2.is_zero())
60    }
61
62    pub fn id(size: usize) -> Self
63    where R: Zero + One { 
64        let inner = DMatrix::identity(size, size);
65        Self::from(inner)
66    }
67
68    pub fn is_id(&self) -> bool
69    where R: Zero + One { 
70        self.is_square() && self.iter().all(|(i, j, a)| 
71            i == j && a.is_one() || 
72            i != j && a.is_zero()
73        )
74    }
75
76    pub fn diag<I>(shape: (usize, usize), entries: I) -> Self
77    where R: Zero, I: IntoIterator<Item = R> {
78        let mut mat = Self::zero(shape);
79        for (i, a) in entries.into_iter().enumerate() {
80            mat[(i, i)] = a;
81        }
82        mat
83    }
84
85    pub fn is_diag(&self) -> bool
86    where R: Zero { 
87        self.iter().all(|(i, j, a)| 
88            i == j || a.is_zero()
89        )
90    }
91
92    pub fn submat(&self, rows: Range<usize>, cols: Range<usize>) -> Mat<R> { 
93        let (i0, i1) = (rows.start, rows.end);
94        let (j0, j1) = (cols.start, cols.end);
95
96        assert!(i0 <= i1 && i1 <= self.nrows());
97        assert!(j0 <= j1 && j1 <= self.ncols());
98
99        let slice = self.inner.view((i0, j0), (i1 - i0, j1 - j0));
100        Self::from(slice.clone_owned())
101    }
102
103    pub fn submat_rows(&self, rows: Range<usize>) -> Mat<R> { 
104        let n = self.ncols();
105        self.submat(rows, 0 .. n)
106    }
107
108    pub fn submat_cols(&self, cols: Range<usize>) -> Mat<R> { 
109        let m = self.nrows();
110        self.submat(0 .. m, cols)
111    }
112
113    pub fn into_sparse(self) -> SpMat<R>
114    where R: Zero + ClosedAddAssign { 
115        self.into()
116    }
117}
118
119impl<R> From<DMatrix<R>> for Mat<R> {
120    fn from(inner: DMatrix<R>) -> Self {
121        Self { inner }
122    }
123}
124
125impl<R> From<SpMat<R>> for Mat<R>
126where R: Scalar + Zero + ClosedAddAssign {
127    fn from(value: SpMat<R>) -> Self {
128        let inner = DMatrix::from(value.inner());
129        Self::from(inner)
130    }
131}
132 
133impl<R> Index<(usize, usize)> for Mat<R> {
134    type Output = R;
135    delegate! { 
136        to self.inner { 
137            fn index(&self, index: (usize, usize)) -> &R;
138        }
139    }
140}
141
142impl<R> IndexMut<(usize, usize)> for Mat<R> {
143    delegate! { 
144        to self.inner { 
145            fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output;
146        }
147    }
148}
149
150impl<R> Default for Mat<R>
151where R: Scalar + Zero {
152    fn default() -> Self {
153        Self::zero((0, 0))
154    }
155}
156
157impl<R> Neg for Mat<R>
158where R: Scalar + Neg<Output = R> {
159    type Output = Self;
160    fn neg(self) -> Self::Output {
161        Mat::from(-self.inner)
162    }
163}
164
165impl<R> Neg for &Mat<R>
166where R: Scalar + Neg<Output = R> {
167    type Output = Mat<R>;
168    fn neg(self) -> Self::Output {
169        Mat::from(-&self.inner)
170    }
171}
172
173#[auto_ops]
174impl<R> AddAssign<&Mat<R>> for Mat<R>
175where R: Scalar + ClosedAddAssign {
176    fn add_assign(&mut self, rhs: &Self) {
177        self.inner += &rhs.inner;
178    }
179}
180
181#[auto_ops]
182impl<R> SubAssign<&Mat<R>> for Mat<R>
183where R: Scalar + ClosedSubAssign {
184    fn sub_assign(&mut self, rhs: &Self) {
185        self.inner -= &rhs.inner
186    }
187}
188
189#[auto_ops]
190impl<'a, 'b, R> Mul<&'b Mat<R>> for &'a Mat<R>
191where R: Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign {
192    type Output = Mat<R>;
193    fn mul(self, rhs: &'b Mat<R>) -> Self::Output {
194        let prod = &self.inner * &rhs.inner;
195        Mat::from(prod)
196    }
197}
198
199impl<R> Mat<R>
200where R: Scalar { 
201    pub fn swap_rows(&mut self, i: usize, j: usize) {
202        self.inner.swap_rows(i, j);
203    }
204
205    pub fn swap_cols(&mut self, i: usize, j: usize) {
206        self.inner.swap_columns(i, j);
207    }
208
209    pub fn mul_row(&mut self, i: usize, r: &R)
210    where R: ClosedMulAssign {
211        self.inner.row_mut(i).mul_assign(r.clone())
212    }
213
214    pub fn mul_col(&mut self, j: usize, r: &R)
215    where R: ClosedMulAssign {
216        self.inner.column_mut(j).mul_assign(r.clone())
217    }
218
219    pub fn add_row_to(&mut self, i: usize, j: usize, r: &R)
220    where R: ClosedAddAssign + ClosedMulAssign { 
221        let row = self.inner.row(i).mul(r.clone());
222        self.inner.row_mut(j).add_assign(row)
223    }
224
225    pub fn add_col_to(&mut self, i: usize, j: usize, r: &R)
226    where R: ClosedAddAssign + ClosedMulAssign {  
227        let col = self.inner.column(i).mul(r.clone());
228        self.inner.column_mut(j).add_assign(col)
229    }
230
231    // Multiply [a, b; c, d] from left. 
232    pub fn left_elementary(&mut self, comps: [&R; 4], i: usize, j: usize)
233    where R: ClosedAddAssign + ClosedMulAssign { 
234        let [a, b, c, d] = comps.map(Clone::clone);
235
236        let r_i = self.inner.row(i);
237        let r_j = self.inner.row(j);
238        
239        let s_i = &r_i * a + &r_j * b;
240        let s_j = &r_i * c + &r_j * d;
241
242        self.inner.set_row(i, &s_i);
243        self.inner.set_row(j, &s_j);
244    }
245
246    // Multiply [a, c; b, d] from right. 
247    pub fn right_elementary(&mut self, comps: [&R; 4], i: usize, j: usize) 
248    where R: ClosedAddAssign + ClosedMulAssign { 
249        let [a, b, c, d] = comps.map(Clone::clone);
250
251        let r_i = self.inner.column(i);
252        let r_j = self.inner.column(j);
253        
254        let s_i = &r_i * a + &r_j * b;
255        let s_j = &r_i * c + &r_j * d;
256
257        self.inner.set_column(i, &s_i);
258        self.inner.set_column(j, &s_j);
259    }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn init() {
        let a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);

        assert_eq!(a.nrows(), 2);
        assert_eq!(a.ncols(), 3);
        assert_eq!(a.into_inner(), DMatrix::from_row_slice(2, 3, &[1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn default_is_empty() {
        let a: Mat<i32> = Mat::default();
        assert_eq!(a.shape(), (0, 0));
    }

    #[test]
    fn iter_entries() {
        let a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        let entries: Vec<_> = a.iter().map(|(i, j, &x)| (i, j, x)).collect();
        // Column-major: down each column, then across.
        assert_eq!(
            entries,
            vec![(0, 0, 1), (1, 0, 4), (0, 1, 2), (1, 1, 5), (0, 2, 3), (1, 2, 6)]
        );
    }

    #[test]
    fn eq() {
        let a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        let b = Mat::from_data((2, 3), [1, 2, 0, 4, 5, 6]);
        let c = Mat::from_data((3, 2), [1, 2, 3, 4, 5, 6]);

        assert_eq!(a, a);
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn square() {
        let a: Mat<i32> = Mat::zero((3, 3));
        assert!(a.is_square());

        let a: Mat<i32> = Mat::zero((3, 2));
        assert!(!a.is_square());
    }

    #[test]
    fn zero() {
        let a: Mat<i32> = Mat::zero((3, 2));
        assert!(a.is_zero());

        let a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        assert!(!a.is_zero());
    }

    #[test]
    fn id() {
        let a: Mat<i32> = Mat::id(3);
        assert!(a.is_id());

        let a = Mat::from_data((2, 2), [1, 2, 3, 4]);
        assert!(!a.is_id());

        let a = Mat::from_data((2, 3), [1, 0, 0, 0, 1, 0]);
        assert!(!a.is_id());
    }

    #[test]
    fn diag() {
        let a: Mat<i32> = Mat::diag((2, 3), [1, 2]);
        assert_eq!(a, Mat::from_data((2, 3), [1, 0, 0, 0, 2, 0]));
        assert!(a.is_diag());

        let b = Mat::from_data((2, 2), [1, 1, 0, 1]);
        assert!(!b.is_diag());
    }

    #[test]
    #[should_panic]
    fn diag_too_many_entries() {
        let _: Mat<i32> = Mat::diag((2, 3), [1, 2, 3]);
    }

    #[test]
    fn swap_rows() {
        let mut a = Mat::from_data((3, 4), 1..=12);
        a.swap_rows(0, 1);
        assert_eq!(a, Mat::from_data((3, 4), [5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12]));
    }

    #[test]
    fn swap_cols() {
        let mut a = Mat::from_data((3, 4), 1..=12);
        a.swap_cols(0, 1);
        assert_eq!(a, Mat::from_data((3, 4), [2, 1, 3, 4, 6, 5, 7, 8, 10, 9, 11, 12]));
    }

    #[test]
    fn mul_row() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.mul_row(1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1, 2, 3, 40, 50, 60, 7, 8, 9]));
    }

    #[test]
    fn mul_col() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.mul_col(1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1, 20, 3, 4, 50, 6, 7, 80, 9]));
    }

    #[test]
    fn add_row_to() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.add_row_to(0, 1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1, 2, 3, 14, 25, 36, 7, 8, 9]));
    }

    #[test]
    fn add_col_to() {
        let mut a = Mat::from_data((3, 3), 1..=9);
        a.add_col_to(0, 1, &10);
        assert_eq!(a, Mat::from_data((3, 3), [1, 12, 3, 4, 45, 6, 7, 78, 9]));
    }

    #[test]
    fn left_elementary() {
        let mut a = Mat::from_data((3, 2), [1, 2, 3, 4, 5, 6]);
        a.left_elementary([&1, &2, &3, &4], 0, 2);
        // row0 = 1*[1,2] + 2*[5,6]; row2 = 3*[1,2] + 4*[5,6]
        assert_eq!(a, Mat::from_data((3, 2), [11, 14, 3, 4, 23, 30]));
    }

    #[test]
    fn right_elementary() {
        let mut a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        a.right_elementary([&1, &2, &3, &4], 0, 2);
        // col0 = 1*[1,4] + 2*[3,6]; col2 = 3*[1,4] + 4*[3,6]
        assert_eq!(a, Mat::from_data((2, 3), [7, 2, 15, 16, 5, 36]));
    }

    #[test]
    fn add() {
        let a = Mat::from_data((3, 2), [1, 2, 3, 4, 5, 6]);
        let b = Mat::from_data((3, 2), [8, 2, 4, 0, 2, 1]);
        let c = a + b;
        assert_eq!(c, Mat::from_data((3, 2), [9, 4, 7, 4, 7, 7]));
    }

    #[test]
    fn sub() {
        let a = Mat::from_data((3, 2), [1, 2, 3, 4, 5, 6]);
        let b = Mat::from_data((3, 2), [8, 2, 4, 0, 2, 1]);
        let c = a - b;
        assert_eq!(c, Mat::from_data((3, 2), [-7, 0, -1, 4, 3, 5]));
    }

    #[test]
    fn neg() {
        let a = Mat::from_data((3, 2), [1, 2, 3, 4, 5, 6]);
        assert_eq!(-a, Mat::from_data((3, 2), [-1, -2, -3, -4, -5, -6]));
    }

    #[test]
    fn neg_ref() {
        let a = Mat::from_data((2, 2), [1, -2, 3, 0]);
        assert_eq!(-&a, Mat::from_data((2, 2), [-1, 2, -3, 0]));
        // The borrowed operand is left untouched.
        assert_eq!(a, Mat::from_data((2, 2), [1, -2, 3, 0]));
    }

    #[test]
    fn mul() {
        let a = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        let b = Mat::from_data((3, 2), [1, 2, 1, -1, 0, 2]);
        let c = a * b;
        assert_eq!(c, Mat::from_data((2, 2), [3, 6, 9, 15]));
    }

    #[test]
    fn to_sparse() {
        let dns = Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]);
        let sps = dns.into_sparse();
        assert_eq!(sps, SpMat::from_dense_data((2, 3), [1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn from_sparse() {
        let sps = SpMat::from_dense_data((2, 3), [1, 2, 3, 4, 5, 6]);
        let dns = Mat::from(sps);
        assert_eq!(dns, Mat::from_data((2, 3), [1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn submat() {
        let a = Mat::from_data((3, 4), [
            1, 2, 3, 7,
            4, 5, 6, 8,
            9, 10, 11, 12,
        ]);
        let b = a.submat(1..3, 2..4);
        assert_eq!(b, Mat::from_data((2, 2), [
             6, 8,
            11, 12,
        ]));
    }

    #[test]
    fn submat_rows_cols() {
        let a = Mat::from_data((3, 3), 1..=9);
        assert_eq!(a.submat_rows(1..3), Mat::from_data((2, 3), [4, 5, 6, 7, 8, 9]));
        assert_eq!(a.submat_cols(0..1), Mat::from_data((3, 1), [1, 4, 7]));
    }
}