1use super::linalg::*;
2
/// A dense square matrix of `f64` values with optional column labels.
///
/// Entries are stored column-major: `mat[c][r]` is the entry in row `r`,
/// column `c`. This is the layout accepted by [`NxN::from_cols`] and returned
/// by [`NxN::to_vec`].
#[derive(Clone, Debug, PartialEq)]
pub struct NxN {
    /// Number of rows, which is also the number of columns.
    pub size: usize,
    /// Optional column labels, copied from the `col_vars` argument of [`NxN::from_cols`].
    pub vars: Option<Vec<String>>,
    /// Column-major storage: `mat[c][r]` is row `r`, column `c`.
    mat: Vec<Vec<f64>>,
}

impl NxN {
    /// Returns the `size` x `size` identity matrix with no column labels.
    pub fn identity(size: usize) -> NxN {
        let mat: Vec<Vec<f64>> = (0..size)
            .map(|c| (0..size).map(|r| if r == c { 1.0 } else { 0.0 }).collect())
            .collect();
        NxN { size, mat, vars: None }
    }

    /// Builds a matrix from a list of columns.
    ///
    /// `cols[c][r]` becomes the entry in row `r`, column `c`; every value is
    /// converted to `f64`. `col_vars`, if given, is stored as owned strings in
    /// [`NxN::vars`].
    ///
    /// # Errors
    /// Returns an error if `cols` is empty, or if any column's length differs
    /// from the number of columns (i.e. the input is not square).
    pub fn from_cols<T>(cols: Vec<Vec<T>>, col_vars: Option<Vec<&str>>) -> Result<NxN, &'static str>
    where
        T: Into<f64> + Copy,
    {
        let size = cols.len();
        // Validate every column, not just the first. A ragged input would
        // otherwise be stored, and row access on a short column would panic.
        if size == 0 || cols.iter().any(|col| col.len() != size) {
            return Err("given vectors do not form an NxN matrix");
        }

        let vars: Option<Vec<String>> =
            col_vars.map(|names| names.into_iter().map(String::from).collect());
        let mat: Vec<Vec<f64>> = cols
            .iter()
            .map(|col| col.iter().map(|&x| x.into()).collect())
            .collect();
        Ok(NxN { size, vars, mat })
    }

    /// Multiplies every entry of row `row` by `scalar`.
    ///
    /// # Panics
    /// Panics if `row` is out of range.
    pub fn scale_row<T>(&mut self, row: usize, scalar: T)
    where
        T: Into<f64> + Copy,
    {
        let factor: f64 = scalar.into();
        for col in self.mat.iter_mut() {
            col[row] *= factor;
        }
    }

    /// Adds `vec[c]` to the entry in row `row`, column `c`, for every column `c`.
    ///
    /// Entries of `vec` beyond `self.size` are ignored.
    ///
    /// # Panics
    /// Panics if `row` is out of range or `vec` has fewer than `self.size` entries.
    pub fn add_to_row<T>(&mut self, row: usize, vec: &[T])
    where
        T: Into<f64> + Copy,
    {
        let n = self.size;
        for (col, &x) in self.mat.iter_mut().zip(vec[..n].iter()) {
            col[row] += x.into();
        }
    }

    /// Returns a copy of row `row` (one value per column, in column order).
    ///
    /// # Panics
    /// Panics if `row` is out of range.
    pub fn get_row(&self, row: usize) -> Vec<f64> {
        self.mat.iter().map(|col| col[row]).collect()
    }

    /// Inverts a 2x2 matrix in place using the closed-form adjugate formula.
    ///
    /// # Errors
    /// Returns an error, leaving `self` unchanged, if the determinant is exactly zero.
    fn invert_2x2(&mut self) -> Result<(), &'static str> {
        let m = &self.mat;
        // mRC is the entry in row R, column C (1-based); storage is `mat[col][row]`.
        let (m11, m12) = (m[0][0], m[1][0]);
        let (m21, m22) = (m[0][1], m[1][1]);

        let det = m11 * m22 - m12 * m21;
        if det == 0.0 {
            return Err("matrix is non-invertible! determinant is 0!");
        }

        // Written column by column to match the column-major storage.
        self.mat = vec![
            vec![m22 / det, -m21 / det],
            vec![-m12 / det, m11 / det],
        ];
        Ok(())
    }

    /// Inverts a 3x3 matrix in place using the closed-form adjugate formula.
    ///
    /// # Errors
    /// Returns an error, leaving `self` unchanged, if the determinant is exactly zero.
    fn invert_3x3(&mut self) -> Result<(), &'static str> {
        let m = &self.mat;
        // mRC is the entry in row R, column C (1-based); storage is `mat[col][row]`.
        let (m11, m12, m13) = (m[0][0], m[1][0], m[2][0]);
        let (m21, m22, m23) = (m[0][1], m[1][1], m[2][1]);
        let (m31, m32, m33) = (m[0][2], m[1][2], m[2][2]);

        let det: f64 = m11*m22*m33 + m21*m32*m13 + m31*m12*m23
            - m11*m32*m23 - m31*m22*m13 - m21*m12*m33;
        if det == 0.0 {
            return Err("matrix is non-invertible! determinant is 0!");
        }

        // Each inner vec is one column of the inverse (column-major storage).
        self.mat = vec![
            vec![
                (m22*m33 - m23*m32) / det,
                (m23*m31 - m21*m33) / det,
                (m21*m32 - m22*m31) / det,
            ],
            vec![
                (m13*m32 - m12*m33) / det,
                (m11*m33 - m13*m31) / det,
                (m12*m31 - m11*m32) / det,
            ],
            vec![
                (m12*m23 - m13*m22) / det,
                (m13*m21 - m11*m23) / det,
                (m11*m22 - m12*m21) / det,
            ],
        ];
        Ok(())
    }

    /// Inverts a 4x4 matrix in place using the closed-form adjugate formula.
    ///
    /// # Errors
    /// Returns an error, leaving `self` unchanged, if the determinant is exactly zero.
    fn invert_4x4(&mut self) -> Result<(), &'static str> {
        let m = &self.mat;
        // aRC is the entry in row R, column C (1-based); storage is `mat[col][row]`.
        let (a11, a12, a13, a14) = (m[0][0], m[1][0], m[2][0], m[3][0]);
        let (a21, a22, a23, a24) = (m[0][1], m[1][1], m[2][1], m[3][1]);
        let (a31, a32, a33, a34) = (m[0][2], m[1][2], m[2][2], m[3][2]);
        let (a41, a42, a43, a44) = (m[0][3], m[1][3], m[2][3], m[3][3]);

        // cRC is entry (R, C) of the adjugate adj(A), so that A^-1 = adj(A) / det(A).
        let c11 = a22*a33*a44 + a23*a34*a42 + a24*a32*a43 - a22*a34*a43 - a23*a32*a44 - a24*a33*a42;
        let c12 = a12*a34*a43 + a13*a32*a44 + a14*a33*a42 - a12*a33*a44 - a13*a34*a42 - a14*a32*a43;
        let c13 = a12*a23*a44 + a13*a24*a42 + a14*a22*a43 - a12*a24*a43 - a13*a22*a44 - a14*a23*a42;
        let c14 = a12*a24*a33 + a13*a22*a34 + a14*a23*a32 - a12*a23*a34 - a13*a24*a32 - a14*a22*a33;
        let c21 = a21*a34*a43 + a23*a31*a44 + a24*a33*a41 - a21*a33*a44 - a23*a34*a41 - a24*a31*a43;
        let c22 = a11*a33*a44 + a13*a34*a41 + a14*a31*a43 - a11*a34*a43 - a13*a31*a44 - a14*a33*a41;
        let c23 = a11*a24*a43 + a13*a21*a44 + a14*a23*a41 - a11*a23*a44 - a13*a24*a41 - a14*a21*a43;
        let c24 = a11*a23*a34 + a13*a24*a31 + a14*a21*a33 - a11*a24*a33 - a13*a21*a34 - a14*a23*a31;
        let c31 = a21*a32*a44 + a22*a34*a41 + a24*a31*a42 - a21*a34*a42 - a22*a31*a44 - a24*a32*a41;
        let c32 = a11*a34*a42 + a12*a31*a44 + a14*a32*a41 - a11*a32*a44 - a12*a34*a41 - a14*a31*a42;
        let c33 = a11*a22*a44 + a12*a24*a41 + a14*a21*a42 - a11*a24*a42 - a12*a21*a44 - a14*a22*a41;
        let c34 = a11*a24*a32 + a12*a21*a34 + a14*a22*a31 - a11*a22*a34 - a12*a24*a31 - a14*a21*a32;
        let c41 = a21*a33*a42 + a22*a31*a43 + a23*a32*a41 - a21*a32*a43 - a22*a33*a41 - a23*a31*a42;
        let c42 = a11*a32*a43 + a12*a33*a41 + a13*a31*a42 - a11*a33*a42 - a12*a31*a43 - a13*a32*a41;
        let c43 = a11*a23*a42 + a12*a21*a43 + a13*a22*a41 - a11*a22*a43 - a12*a23*a41 - a13*a21*a42;
        let c44 = a11*a22*a33 + a12*a23*a31 + a13*a21*a32 - a11*a23*a32 - a12*a21*a33 - a13*a22*a31;

        // A * adj(A) = det(A) * I. Its (1,1) entry is the Laplace expansion of
        // det(A) along the first row. Deriving det from the adjugate keeps the two
        // consistent, with no separate 24-term formula to get wrong.
        let det = a11 * c11 + a12 * c21 + a13 * c31 + a14 * c41;
        if det == 0.0 {
            return Err("matrix is non-invertible! determinant is 0!");
        }

        // Each inner vec is one column of the inverse (column-major storage).
        self.mat = vec![
            vec![c11 / det, c21 / det, c31 / det, c41 / det],
            vec![c12 / det, c22 / det, c32 / det, c42 / det],
            vec![c13 / det, c23 / det, c33 / det, c43 / det],
            vec![c14 / det, c24 / det, c34 / det, c44 / det],
        ];
        Ok(())
    }

    /// Inverts a matrix of any size in place by Gauss-Jordan elimination with
    /// partial pivoting.
    ///
    /// [`NxN::invert`] uses this for every size without a closed-form routine
    /// (0, 1, and 5 or more).
    ///
    /// # Errors
    /// Returns an error, leaving `self` unchanged, if some column has no nonzero
    /// pivot candidate, i.e. the matrix is singular.
    fn invert_nxn(&mut self) -> Result<(), &'static str> {
        let n = self.size;
        // Work on row-major copies for two reasons: each row operation then touches
        // a single inner Vec, and `self` is only overwritten once inversion succeeds.
        let mut a: Vec<Vec<f64>> = (0..n).map(|r| self.get_row(r)).collect();
        // The identity is symmetric, so its column-major storage is also row-major.
        let mut inv = NxN::identity(n).to_vec();

        for c in 0..n {
            // Partial pivoting: use the row at or below `c` with the largest
            // magnitude in column `c`. Without it, an invertible matrix with a
            // zero on the diagonal would be rejected.
            let p = ((c + 1)..n).fold(c, |best, r| {
                if a[r][c].abs() > a[best][c].abs() {
                    r
                } else {
                    best
                }
            });
            if a[p][c] == 0.0 {
                return Err("division by zero encountered during matrix inversion");
            }
            a.swap(c, p);
            inv.swap(c, p);

            // Scale the pivot row so that the pivot becomes 1.
            let pivot = a[c][c];
            for x in a[c].iter_mut() {
                *x /= pivot;
            }
            for x in inv[c].iter_mut() {
                *x /= pivot;
            }

            // Eliminate column `c` from every other row.
            for r in 0..n {
                let factor = a[r][c];
                if r == c || factor == 0.0 {
                    continue;
                }
                for k in 0..n {
                    let (pa, pi) = (a[c][k], inv[c][k]);
                    a[r][k] -= factor * pa;
                    inv[r][k] -= factor * pi;
                }
            }
        }

        // Transpose the row-major result back to column-major storage.
        self.mat = (0..n)
            .map(|c| inv.iter().map(|row| row[c]).collect::<Vec<f64>>())
            .collect();
        Ok(())
    }

    /// Replaces this matrix with its inverse.
    ///
    /// Sizes 2, 3 and 4 use closed-form adjugate formulas. Every other size uses
    /// Gauss-Jordan elimination with partial pivoting.
    ///
    /// # Errors
    /// Returns an error if the matrix is singular (an exactly-zero determinant or
    /// pivot). On error the matrix is left unchanged.
    pub fn invert(&mut self) -> Result<(), &'static str> {
        match self.size {
            2 => self.invert_2x2(),
            3 => self.invert_3x3(),
            4 => self.invert_4x4(),
            _ => self.invert_nxn(),
        }
    }

    /// Consumes the matrix and returns its column-major storage.
    ///
    /// In the result, `result[c][r]` is the entry in row `r`, column `c`.
    pub fn to_vec(self) -> Vec<Vec<f64>> {
        self.mat
    }
}