Skip to main content

agg_rust/
simul_eq.rs

//! Solving simultaneous equations via Gaussian elimination.
//!
//! Port of `agg_simul_eq.h` — solves systems of linear equations using
//! Gaussian elimination with partial pivoting.

// ============================================================================
// Simultaneous equation solver
// ============================================================================

/// Solve the system `left * X = right` for `X`.
///
/// Uses Gaussian elimination with partial pivoting. For each column `k`:
///
/// * The row at or below `k` whose entry in that column has the largest
///   non-zero magnitude is swapped into row `k`. The first such row wins on ties.
/// * The pivot row is scaled so the pivot becomes `1.0`.
/// * Column `k` is eliminated from every row below it.
///
/// Back substitution then recovers each column of `X`.
///
/// # Arguments
///
/// * `left` – the `SIZE × SIZE` coefficient matrix (row-major).
/// * `right` – the `RIGHT_COLS` right-hand-side columns, stored row-major.
/// * `result` – receives the solution, one column per right-hand side.
///
/// # Returns
///
/// `true` if a solution was written to `result`.
///
/// `false` if the matrix is singular, meaning some pivot column contains only
/// exact `0.0` values at and below the diagonal. In that case `result` is left
/// unmodified.
///
/// Singularity is tested against exact zero, so a nearly-singular matrix is
/// still "solved" and may yield very large values. Non-finite inputs are not
/// rejected and may propagate into `result`.
///
/// Working copies of `left` and `right` live on the stack, so no heap
/// allocation is made. This suits small fixed-size systems; very large `SIZE`
/// values may exhaust the stack.
///
/// Port of C++ `simul_eq<Size, RightCols>::solve`.
// Explicit indices are kept on purpose: each elimination step reads the pivot
// row while writing other rows of the same matrix, which is clearest (and
// mirrors the original algorithm) with index loops.
#[allow(clippy::needless_range_loop)]
pub fn simul_eq_solve<const SIZE: usize, const RIGHT_COLS: usize>(
    left: &[[f64; SIZE]; SIZE],
    right: &[[f64; RIGHT_COLS]; SIZE],
    result: &mut [[f64; RIGHT_COLS]; SIZE],
) -> bool {
    // The augmented matrix `[left | right]` would need an array length of
    // `SIZE + RIGHT_COLS`, which stable Rust cannot express. Instead keep the
    // two halves as separate (Copy) arrays and apply every row operation to
    // both. The per-element arithmetic is identical, with no heap allocation.
    let mut a = *left;
    let mut b = *right;

    // Forward elimination with partial pivoting.
    for k in 0..SIZE {
        let mut pivot_row = k;
        let mut max_val = -1.0_f64;
        for i in k..SIZE {
            let magnitude = a[i][k].abs();
            // Strict `>` keeps the first row on ties; zeros are never chosen.
            if magnitude > max_val && magnitude != 0.0 {
                max_val = magnitude;
                pivot_row = i;
            }
        }
        if a[pivot_row][k] == 0.0 {
            return false; // Singular: nothing has been written to `result`.
        }
        if pivot_row != k {
            a.swap(pivot_row, k);
            b.swap(pivot_row, k);
        }

        // Normalize the pivot row so that a[k][k] becomes 1.0. Entries left of
        // column k are already zero and are skipped.
        let pivot = a[k][k];
        for v in &mut a[k][k..] {
            *v /= pivot;
        }
        for v in &mut b[k] {
            *v /= pivot;
        }

        // Eliminate column k from every row below the pivot. `factor` is read
        // before the loop because the j == k step overwrites a[i][k].
        for i in (k + 1)..SIZE {
            let factor = a[i][k];
            for j in k..SIZE {
                a[i][j] -= factor * a[k][j];
            }
            for j in 0..RIGHT_COLS {
                b[i][j] -= factor * b[k][j];
            }
        }
    }

    // Back substitution: `a` is now unit upper-triangular.
    for k in 0..RIGHT_COLS {
        for m in (0..SIZE).rev() {
            let mut x = b[m][k];
            for j in (m + 1)..SIZE {
                x -= a[m][j] * result[j][k];
            }
            result[m][k] = x;
        }
    }

    true
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed solutions.
    const EPS: f64 = 1e-10;

    /// Assert that every entry of `actual` is within `EPS` of `expected`,
    /// reporting the first mismatching position.
    fn assert_close<const R: usize, const C: usize>(
        actual: &[[f64; C]; R],
        expected: &[[f64; C]; R],
    ) {
        for (i, (a_row, e_row)) in actual.iter().zip(expected).enumerate() {
            for (j, (a, e)) in a_row.iter().zip(e_row).enumerate() {
                assert!((a - e).abs() < EPS, "result[{i}][{j}] = {a}, expected {e}");
            }
        }
    }

    /// Compute `left * x`, so a test can build a right-hand side whose exact
    /// solution is `x`.
    fn mat_mul<const N: usize, const C: usize>(
        left: &[[f64; N]; N],
        x: &[[f64; C]; N],
    ) -> [[f64; C]; N] {
        let mut out = [[0.0; C]; N];
        for (out_row, left_row) in out.iter_mut().zip(left) {
            for (c, cell) in out_row.iter_mut().enumerate() {
                *cell = left_row.iter().zip(x).map(|(l, x_row)| l * x_row[c]).sum();
            }
        }
        out
    }

    #[test]
    fn test_identity_system() {
        // I * x = b => x = b
        let left = [[1.0, 0.0], [0.0, 1.0]];
        let right = [[3.0], [7.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &[[3.0], [7.0]]);
    }

    #[test]
    fn test_2x2_system() {
        // 2x + y = 5
        // x + 3y = 10
        // Solution: x = 1, y = 3
        let left = [[2.0, 1.0], [1.0, 3.0]];
        let right = [[5.0], [10.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &[[1.0], [3.0]]);
    }

    #[test]
    fn test_3x3_system() {
        // x + y + z = 6
        // 2x + y - z = 1
        // x - y + z = 2
        // Solution: x = 1, y = 2, z = 3
        let left = [[1.0, 1.0, 1.0], [2.0, 1.0, -1.0], [1.0, -1.0, 1.0]];
        let right = [[6.0], [1.0], [2.0]];
        let mut result = [[0.0]; 3];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &[[1.0], [2.0], [3.0]]);
    }

    #[test]
    fn test_singular_matrix() {
        // Singular: rows are linearly dependent
        let left = [[1.0, 2.0], [2.0, 4.0]];
        let right = [[3.0], [6.0]];
        let mut result = [[42.0]; 2];
        assert!(!simul_eq_solve(&left, &right, &mut result));
        // On failure the output buffer must not be modified.
        assert_eq!(result, [[42.0]; 2]);
    }

    #[test]
    fn test_zero_column_is_singular() {
        // The first column is entirely zero, so no pivot can be found.
        let left = [[0.0, 1.0], [0.0, 2.0]];
        let right = [[1.0], [2.0]];
        let mut result = [[42.0]; 2];
        assert!(!simul_eq_solve(&left, &right, &mut result));
        assert_eq!(result, [[42.0]; 2]);
    }

    #[test]
    fn test_multiple_right_columns() {
        // Solve for two RHS vectors simultaneously
        let left = [[1.0, 0.0], [0.0, 1.0]];
        let right = [[1.0, 2.0], [3.0, 4.0]];
        let mut result = [[0.0; 2]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &right);
    }

    #[test]
    fn test_4x4_system() {
        // 4x4 system (used by parl_to_parl in trans_affine)
        let left = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ];
        let right = [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut result = [[0.0; 2]; 4];
        assert!(simul_eq_solve(&left, &right, &mut result));
        // Identity system: every entry of the solution equals the RHS.
        assert_close(&result, &right);
    }

    #[test]
    fn test_4x4_general_system() {
        // Non-diagonal 4x4 system with two RHS columns and a known solution.
        let left = [
            [2.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 4.0, 1.0],
            [0.0, 0.0, 1.0, 5.0],
        ];
        let expected = [[1.0, -1.0], [2.0, 0.0], [3.0, 1.0], [4.0, 2.0]];
        let right = mat_mul(&left, &expected);
        let mut result = [[0.0; 2]; 4];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &expected);
    }

    #[test]
    fn test_needs_pivoting() {
        // First row has zero in pivot position, requires row swap
        let left = [[0.0, 1.0], [1.0, 0.0]];
        let right = [[3.0], [5.0]];
        let mut result = [[0.0]; 2];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &[[5.0], [3.0]]);
    }

    #[test]
    fn test_pivoting_after_first_column() {
        // Eliminating column 0 leaves a zero on the diagonal of row 1, which
        // must be swapped with row 2 at the second elimination step.
        let left = [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]];
        let expected = [[1.0], [2.0], [3.0]];
        let right = mat_mul(&left, &expected);
        let mut result = [[0.0]; 3];
        assert!(simul_eq_solve(&left, &right, &mut result));
        assert_close(&result, &expected);
    }

    #[test]
    fn test_empty_system() {
        // A 0x0 system is trivially solvable and writes nothing.
        let left: [[f64; 0]; 0] = [];
        let right: [[f64; 1]; 0] = [];
        let mut result: [[f64; 1]; 0] = [];
        assert!(simul_eq_solve(&left, &right, &mut result));
    }
}