/// Solves the linear system `left · X = right` for `X` using Gaussian
/// elimination with partial pivoting, writing `X` into `result`.
///
/// `right` may hold several right-hand-side columns (`RIGHT_COLS`); all of
/// them are solved against the same `left` matrix in one elimination pass.
///
/// Returns `true` and fills `result` on success. Returns `false`, leaving
/// `result` unmodified, when some column has no non-zero pivot, i.e. `left`
/// is exactly singular. Because that test is exact, a matrix that is singular
/// only up to rounding error (for example `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]`)
/// may still "succeed" with a numerically meaningless answer; use
/// [`simul_eq_solve_with_tolerance`] to reject such pivots. NaN inputs are not
/// specifically detected.
pub fn simul_eq_solve<const SIZE: usize, const RIGHT_COLS: usize>(
    left: &[[f64; SIZE]; SIZE],
    right: &[[f64; RIGHT_COLS]; SIZE],
    result: &mut [[f64; RIGHT_COLS]; SIZE],
) -> bool {
    simul_eq_solve_with_tolerance(left, right, result, 0.0)
}

/// Like [`simul_eq_solve`], but treats any pivot whose magnitude is at most
/// `tol` as zero, so near-singular systems are reported as unsolvable instead
/// of producing a meaningless `result`.
///
/// `tol` is an absolute threshold compared against each pivot's magnitude
/// *before* its row is normalised, so choose it relative to the scale of the
/// entries of `left`. Negative or NaN values of `tol` are treated as `0.0`,
/// which reproduces [`simul_eq_solve`] exactly.
///
/// Returns `true` and fills `result` on success; returns `false` and leaves
/// `result` unmodified otherwise.
pub fn simul_eq_solve_with_tolerance<const SIZE: usize, const RIGHT_COLS: usize>(
    left: &[[f64; SIZE]; SIZE],
    right: &[[f64; RIGHT_COLS]; SIZE],
    result: &mut [[f64; RIGHT_COLS]; SIZE],
    tol: f64,
) -> bool {
    // `f64::max` ignores a NaN operand, so NaN and negative values both become 0.0.
    let tol = tol.max(0.0);

    // Working copies of the coefficient rows and the right-hand-side rows.
    // Two row vectors of fixed-size arrays need only two heap allocations,
    // unlike an augmented `Vec<Vec<f64>>` (one allocation per row).
    let mut a: Vec<[f64; SIZE]> = left.to_vec();
    let mut b: Vec<[f64; RIGHT_COLS]> = right.to_vec();

    // Forward elimination: reduce `a` to unit upper-triangular form.
    for k in 0..SIZE {
        // Partial pivoting: pick the first row at or below `k` whose entry in
        // column `k` has the strictly largest magnitude. Zero and NaN entries
        // never win, so if none qualifies the candidate stays row `k`.
        let mut pivot_row = k;
        let mut best = 0.0_f64;
        for (offset, row) in a[k..].iter().enumerate() {
            let magnitude = row[k].abs();
            if magnitude > best {
                best = magnitude;
                pivot_row = k + offset;
            }
        }
        if a[pivot_row][k].abs() <= tol {
            return false;
        }
        a.swap(pivot_row, k);
        b.swap(pivot_row, k);

        // Scale the pivot row so its diagonal entry becomes 1.
        let pivot = a[k][k];
        a[k][k..].iter_mut().for_each(|x| *x /= pivot);
        b[k].iter_mut().for_each(|x| *x /= pivot);

        // Eliminate column `k` from every row below the pivot row. Splitting
        // the vectors lets us read the pivot row while mutating the others.
        let (a_done, a_rest) = a.split_at_mut(k + 1);
        let (b_done, b_rest) = b.split_at_mut(k + 1);
        let (a_pivot, b_pivot) = (&a_done[k], &b_done[k]);
        for (a_row, b_row) in a_rest.iter_mut().zip(b_rest.iter_mut()) {
            let factor = a_row[k];
            for (x, &p) in a_row[k..].iter_mut().zip(&a_pivot[k..]) {
                *x -= factor * p;
            }
            for (x, &p) in b_row.iter_mut().zip(b_pivot) {
                *x -= factor * p;
            }
        }
    }

    // Back substitution, bottom row first. Row `m` depends only on rows below
    // it, which are already solved in `result`; subtraction runs in ascending
    // column order for every right-hand side.
    for m in (0..SIZE).rev() {
        let (head, solved) = result.split_at_mut(m + 1);
        let out = &mut head[m];
        *out = b[m];
        for (&coef, below) in a[m][m + 1..].iter().zip(solved.iter()) {
            for (x, &s) in out.iter_mut().zip(below) {
                *x -= coef * s;
            }
        }
    }

    true
}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed solutions.
    const EPS: f64 = 1e-10;

    /// Asserts that every entry of `actual` is within `EPS` of `expected`,
    /// reporting the first mismatching position on failure.
    fn assert_solution<const N: usize, const C: usize>(
        actual: &[[f64; C]; N],
        expected: &[[f64; C]; N],
    ) {
        for (row, (got_row, want_row)) in actual.iter().zip(expected).enumerate() {
            for (col, (&got, &want)) in got_row.iter().zip(want_row).enumerate() {
                assert!(
                    (got - want).abs() < EPS,
                    "result[{}][{}] = {}, expected {}",
                    row,
                    col,
                    got,
                    want
                );
            }
        }
    }

    #[test]
    fn test_identity_system() {
        let left = [[1.0, 0.0], [0.0, 1.0]];
        let right = [[3.0], [7.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[3.0], [7.0]]);
    }

    #[test]
    fn test_2x2_system() {
        let left = [[2.0, 1.0], [1.0, 3.0]];
        let right = [[5.0], [10.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[1.0], [3.0]]);
    }

    #[test]
    fn test_3x3_system() {
        let left = [[1.0, 1.0, 1.0], [2.0, 1.0, -1.0], [1.0, -1.0, 1.0]];
        let right = [[6.0], [1.0], [2.0]];
        let mut result = [[0.0]; 3];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[1.0], [2.0], [3.0]]);
    }

    #[test]
    fn test_3x3_multiple_right_columns() {
        // Column 0 solves to (1, 2, 3); column 1 solves to (2, 0, -1).
        let left = [[1.0, 1.0, 1.0], [2.0, 1.0, -1.0], [1.0, -1.0, 1.0]];
        let right = [[6.0, 1.0], [1.0, 5.0], [2.0, 1.0]];
        let mut result = [[0.0; 2]; 3];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[1.0, 2.0], [2.0, 0.0], [3.0, -1.0]]);
    }

    #[test]
    fn test_singular_matrix() {
        let left = [[1.0, 2.0], [2.0, 4.0]];
        let right = [[3.0], [6.0]];
        let mut result = [[42.0]; 2];
        assert!(!simul_eq_solve(&left, &right, &mut result));
        // A failed solve must not leave partial output behind.
        assert_eq!(result, [[42.0]; 2]);
    }

    #[test]
    fn test_multiple_right_columns() {
        let left = [[1.0, 0.0], [0.0, 1.0]];
        let right = [[1.0, 2.0], [3.0, 4.0]];
        let mut result = [[0.0; 2]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[1.0, 2.0], [3.0, 4.0]]);
    }

    #[test]
    fn test_4x4_system() {
        let left = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let right = [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut result = [[0.0; 2]; 4];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &right);
    }

    #[test]
    fn test_needs_pivoting() {
        let left = [[0.0, 1.0], [1.0, 0.0]];
        let right = [[3.0], [5.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_solution(&result, &[[5.0], [3.0]]);
    }

    #[test]
    fn test_empty_system() {
        let left: [[f64; 0]; 0] = [];
        let right: [[f64; 1]; 0] = [];
        let mut result: [[f64; 1]; 0] = [];
        assert!(simul_eq_solve(&left, &right, &mut result));
    }
}