nexsys_math/
nxn.rs

1use super::linalg::*;
2
/// A square (n x n) matrix of `f64` values.
///
/// `size` holds the dimension n, and `vars` optionally carries one variable
/// name per column, if applicable.
#[derive(Clone, Debug, PartialEq)]
pub struct NxN {
    /// Dimension n of the matrix.
    pub size: usize,
    /// Optional header annotating which variable each column belongs to.
    pub vars: Option<Vec<String>>,
    /// Matrix entries stored as a list of columns, i.e. `mat[column][row]`.
    mat: Vec<Vec<f64>>,
}
12impl NxN {
13
14    /// Initializes an NxN identity matrix of the specified size
15    /// # Example
16    /// ```
17    /// use nexsys_math::NxN;
18    /// 
19    /// let my_matrix = NxN::identity(3);
20    /// let check = vec![ 
21    ///     vec![1.0, 0.0, 0.0], 
22    ///     vec![0.0, 1.0, 0.0], 
23    ///     vec![0.0, 0.0, 1.0] 
24    /// ];
25    /// 
26    /// assert_eq!(my_matrix.to_vec(), check);
27    /// ```
28    pub fn identity(size: usize) -> NxN {
29        let mut mat = vec![];
30        for i in 0..size {
31            let mut col = vec![];
32            for j in 0..size {
33                if i == j {
34                    col.push(1_f64);
35                } else {
36                    col.push(0_f64);
37                }
38            }
39            mat.push(col);
40        }
41        NxN { size, mat, vars: None }
42    }
43
44    /// Initializes an NxN matrix of given values from a `Vec<Vec<f64>>`
45    /// # Example
46    /// ```
47    /// use nexsys_math::NxN;
48    /// 
49    /// let my_vars = vec!["x", "y", "z"];
50    /// let my_cols = vec![
51    ///     vec![1.0, 2.0, 3.0],
52    ///     vec![4.0, 5.0, 6.0],
53    ///     vec![7.0, 8.0, 9.0]
54    /// ];
55    ///  
56    /// let my_matrix = NxN::from_cols(
57    ///     my_cols.clone(), 
58    ///     Some(my_vars)
59    /// ).unwrap();
60    /// 
61    /// assert_eq!(my_matrix.to_vec(), my_cols);
62    /// ```
63    pub fn from_cols<T>(cols: Vec<Vec<T>>, col_vars: Option<Vec<&str>>) -> Result<NxN, &'static str>
64    where
65        T: Into<f64> + Copy
66    {
67        let mut vars = None;
68
69        if let Some(v) = col_vars {
70            vars = Some(v.iter().map(|&i| i.to_string()).collect());
71        }
72
73        if cols.len() != cols[0].len() {
74            return Err("given vectors do not form an NxN matrix")
75        } else {
76            let size = cols.len();
77            let mat = cols.iter().map(
78                |i| {
79                    i.iter()
80                    .map(|&j| j.into())
81                    .collect()
82                }
83            ).collect();
84            Ok(NxN { size, vars, mat })
85        }
86    }
87
88    /// Mutates a row, scaling it by the given value
89    /// # Example
90    /// ```
91    /// use nexsys_math::NxN;
92    /// 
93    /// let mut my_matrix = NxN::identity(3);
94    /// 
95    /// let check = vec![ 
96    ///     vec![1.0, 0.0, 0.0], 
97    ///     vec![0.0, 2.0, 0.0], 
98    ///     vec![0.0, 0.0, 1.0] 
99    /// ];
100    /// 
101    /// my_matrix.scale_row(1, 2);
102    /// 
103    /// assert_eq!(my_matrix.to_vec(), check);
104    /// ```
105    pub fn scale_row<T>(&mut self, row: usize, scalar: T)
106    where
107        T: Into<f64> + Copy
108    { 
109        let n = self.size;
110        for i in 0..n {
111            self.mat[i][row] *= scalar.into();
112        }
113    }
114
115    /// Adds a given row vector to a row in the matrix
116    /// # Example
117    /// ```
118    /// use nexsys_math::NxN;
119    /// 
120    /// let mut my_matrix = NxN::identity(3);
121    /// let check = vec![ 
122    ///     vec![1.0, 2.0, 0.0], 
123    ///     vec![0.0, 3.0, 0.0], 
124    ///     vec![0.0, 2.0, 1.0] 
125    /// ];
126    /// my_matrix.add_to_row(1, &vec![2, 2, 2]);
127    /// assert_eq!(my_matrix.to_vec(), check);
128    /// ```
129    pub fn add_to_row<T>(&mut self, row: usize, vec: &Vec<T>) 
130    where 
131        T: Into<f64> + Copy
132    {
133        let n = self.size;
134        for i in 0..n {
135            self.mat[i][row] += vec[i].into();
136        }
137    }
138
139    /// Returns a row from the matrix
140    /// # Example
141    /// ```
142    /// use nexsys_math::NxN;
143    /// 
144    /// let mut my_matrix = NxN::identity(3);
145    /// 
146    /// let check = vec![0.0, 0.0, 1.0];
147    /// 
148    /// assert_eq!(my_matrix.get_row(2), check);
149    /// ```
150    pub fn get_row(&self, row: usize) -> Vec<f64> {
151        let n = self.size;
152        let mut res = vec![];
153        for i in 0..n {
154            res.push(self.mat[i][row]);
155        }
156        res
157    }
158
159    fn invert_2x2(&mut self) -> Result<(), &'static str> {
160        
161        let m = &self.mat;
162        
163        let m11 = m[0][0];
164        let m12 = m[1][0];
165        let m21 = m[0][1];
166        let m22 = m[1][1];
167
168        let det = m11*m22 - m12*m21;
169
170        if det == 0_f64 {
171            return Err("matrix is non-invertible! determinant is 0!")
172        }
173    
174        self.mat = vec![
175            vec![ // column 1
176                m22/det, 
177                -m21/det
178            ],
179            vec![ // column 2
180                -m12/det,  
181                m11/det
182            ]
183        ];
184
185        Ok(())    
186    }
187
188    fn invert_3x3(&mut self) -> Result<(), &'static str> {
189
190        let m = &self.mat;
191        let m11 = m[0][0];
192        let m12 = m[1][0];
193        let m13 = m[2][0];
194        let m21 = m[0][1];
195        let m22 = m[1][1];
196        let m23 = m[2][1];
197        let m31 = m[0][2];
198        let m32 = m[1][2];
199        let m33 = m[2][2];
200
201        let det:f64 = m11*m22*m33 + m21*m32*m13 + m31*m12*m23 - m11*m32*m23 - m31*m22*m13 - m21*m12*m33;
202
203        if det == 0_f64 {
204            return Err("matrix is non-invertible! determinant is 0!")
205        }
206
207        self.mat = vec![
208            vec![ // column 1
209                (m22*m33 - m23*m32)/det, 
210                (m23*m31 - m21*m33)/det, 
211                (m21*m32 - m22*m31)/det
212            ],
213            vec![ // column 2
214                (m13*m32 - m12*m33)/det,
215                (m11*m33 - m13*m31)/det,
216                (m12*m31 - m11*m32)/det
217            ],
218            vec![ // column 3
219                (m12*m23 - m13*m22)/det,
220                (m13*m21 - m11*m23)/det,
221                (m11*m22 - m12*m21)/det 
222            ],
223        ];
224
225        Ok(())
226    }
227
228    fn invert_4x4(&mut self) -> Result<(), &'static str> {
229        let m = &self.mat;
230        
231        let a11 = m[0][0];
232        let a12 = m[1][0];
233        let a13 = m[2][0];
234        let a14 = m[3][0];
235        let a21 = m[0][1];
236        let a22 = m[1][1];
237        let a23 = m[2][1];
238        let a24 = m[3][1];
239        let a31 = m[0][2];
240        let a32 = m[1][2];
241        let a33 = m[2][2];
242        let a34 = m[3][2];
243        let a41 = m[0][3];
244        let a42 = m[1][3];
245        let a43 = m[2][3];
246        let a44 = m[3][3];
247
248        let det: f64 =  a11*a22*a33*a44 + a11*a23*a34*a42 + a11*a24*a32*a43 +
249                        a12*a21*a34*a43 + a12*a23*a31*a44 + a12*a24*a33*a41 + 
250                        a13*a21*a32*a44 + a13*a22*a34*a41 + a13*a24*a31*a42 + 
251                        a14*a21*a33*a42 + a14*a22*a34*a43 + a14*a23*a32*a41 -
252                        a11*a22*a34*a43 - a11*a23*a32*a44 - a11*a24*a33*a42 -
253                        a12*a21*a33*a44 - a12*a23*a34*a41 - a12*a24*a31*a43 -
254                        a13*a21*a34*a42 - a13*a22*a31*a44 - a13*a24*a32*a41 -
255                        a14*a21*a32*a43 - a14*a22*a33*a41 - a14*a23*a31*a42;
256                        
257        if det == 0_f64 {
258            return Err("matrix is non-invertible! determinant is 0!")
259        }
260
261        let b11 = (a22*a33*a44 + a23*a34*a42 + a24*a32*a43 - a22*a34*a43 - a23*a32*a44 - a24*a33*a42) / det;
262        let b12 = (a12*a34*a43 + a13*a32*a44 + a14*a33*a42 - a12*a33*a44 - a13*a34*a42 - a14*a32*a43) / det;
263        let b13 = (a12*a23*a44 + a13*a24*a42 + a14*a22*a43 - a12*a24*a43 - a13*a22*a44 - a14*a23*a42) / det;
264        let b14 = (a12*a24*a33 + a13*a22*a34 + a14*a23*a32 - a12*a23*a34 - a13*a24*a32 - a14*a22*a33) / det;
265        let b21 = (a21*a34*a43 + a23*a31*a44 + a24*a33*a41 - a21*a33*a44 - a23*a34*a41 - a24*a31*a43) / det;
266        let b22 = (a11*a33*a44 + a13*a34*a41 + a14*a31*a43 - a11*a34*a43 - a13*a31*a44 - a14*a33*a41) / det;
267        let b23 = (a11*a24*a43 + a13*a21*a44 + a14*a23*a41 - a11*a23*a44 - a13*a24*a41 - a14*a21*a43) / det;
268        let b24 = (a11*a23*a34 + a13*a24*a31 + a14*a21*a33 - a11*a24*a33 - a13*a21*a34 - a14*a23*a31) / det;
269        let b31 = (a21*a32*a44 + a22*a34*a41 + a24*a31*a42 - a21*a34*a42 - a22*a31*a44 - a24*a32*a41) / det;
270        let b32 = (a11*a34*a42 + a12*a31*a44 + a14*a32*a41 - a11*a32*a44 - a12*a34*a41 - a14*a31*a42) / det;
271        let b33 = (a11*a22*a44 + a12*a24*a41 + a14*a21*a42 - a11*a24*a42 - a12*a21*a44 - a14*a22*a41) / det;
272        let b34 = (a11*a24*a32 + a12*a21*a34 + a14*a22*a31 - a11*a22*a34 - a12*a24*a31 - a14*a21*a32) / det;
273        let b41 = (a21*a33*a42 + a22*a31*a43 + a23*a32*a41 - a21*a32*a43 - a22*a33*a41 - a23*a31*a42) / det;
274        let b42 = (a11*a32*a43 + a12*a33*a41 + a13*a31*a42 - a11*a33*a42 - a12*a31*a43 - a13*a32*a41) / det;
275        let b43 = (a11*a23*a42 + a12*a21*a43 + a13*a22*a41 - a11*a22*a43 - a12*a23*a41 - a13*a21*a42) / det;
276        let b44 = (a11*a22*a33 + a12*a23*a31 + a13*a21*a32 - a11*a23*a32 - a12*a21*a33 - a13*a22*a31) / det;
277
278        self.mat = vec![
279            vec![b11, b21, b31, b41],
280            vec![b12, b22, b32, b42],
281            vec![b13, b23, b33, b43],
282            vec![b14, b24, b34, b44],     
283        ];
284        
285        Ok(())
286    }
287
288    fn invert_nxn(&mut self) -> Result<(), &'static str> {
289        let n = self.size;
290        let mut inv = NxN::identity(n);
291
292        for c in 0..n {
293            for r in 0..n {
294                if c == r {
295                    continue; // guard clause against modifying the diagonal
296                } else {
297                    if self.mat[c][c] == 0_f64 { 
298                        return Err("division by zero encountered during matrix inversion")
299                    }
300                    // get the scalar that needs to be applied to the row vector
301                    let scalar = - self.mat[c][r] / self.mat[c][c];
302
303                    // create the row vector to add to self & row vector to add to inv
304                    let v = scale_vec(self.get_row(c), scalar);
305                    let vi = scale_vec(inv.get_row(c), scalar);
306
307                    self.add_to_row(r, &v); // add the vector to self
308                    inv.add_to_row(r, &vi); // perform the same operation on the identity matrix
309                }
310            }
311        }
312
313        for i in 0..n {
314            let scalar = 1.0 / self.mat[i][i];
315            self.scale_row(i, scalar);
316            inv.scale_row(i, scalar);
317        }
318
319        // println!("{:?}", self.mat);
320
321        // Assign the identity matrix's values to self.mat
322        self.mat = inv.to_vec();
323        Ok(())
324    }
325
326    /// inverts the matrix, if possible. This method returns a result that
327    /// indicates whether the inversion was successful or not.
328    /// # Example
329    /// ```
330    /// use nexsys_math::NxN;
331    /// 
332    /// let mut my_matrix = NxN::from_cols(vec![ 
333    ///    vec![-1.0, 1.0], 
334    ///    vec![ 1.5,-1.0] 
335    /// ], None).unwrap();
336    /// 
337    /// my_matrix.invert().unwrap();
338    /// 
339    /// let inverse = vec![ 
340    ///     vec![2.0, 2.0], 
341    ///     vec![3.0, 2.0] 
342    /// ];
343    /// 
344    /// assert_eq!(my_matrix.to_vec(), inverse);
345    /// ```
346    pub fn invert(&mut self) -> Result<(), &'static str> {
347
348        // Different inversion methods are chosen to mitigate 
349        // computational expense.
350        if self.size == 2 {
351
352            Ok(self.invert_2x2()?)
353        
354        } else if self.size == 3 {
355
356            Ok(self.invert_3x3()?)
357
358        } else if self.size == 4 {
359
360            Ok(self.invert_4x4()?)
361
362        } else {
363        
364            Ok(self.invert_nxn()?)
365        
366        }
367
368    }
369
370    /// Returns the matrix as `Vec<Vec<f64>>`, consuming the `self` value in the process
371    pub fn to_vec(self) -> Vec<Vec<f64>> {
372        self.mat
373    }
374}