bem/core/solver/
direct.rs

1//! Direct solver (LU factorization)
2//!
3//! Uses ndarray-linalg for LU decomposition and solve.
4
5use ndarray::{Array1, Array2};
6use num_complex::Complex64;
7
8/// Direct solver result
9#[derive(Debug)]
10pub struct DirectSolution {
11    /// Solution vector
12    pub x: Array1<Complex64>,
13    /// Whether the solve was successful
14    pub success: bool,
15}
16
17/// Solve Ax = b using LU factorization
18///
19/// # Arguments
20/// * `a` - Coefficient matrix (n × n)
21/// * `b` - Right-hand side vector (n)
22///
23/// # Returns
24/// Solution struct containing x and success status
25///
26/// # Example
27/// ```ignore
28/// let solution = direct_solve(&matrix, &rhs);
29/// if solution.success {
30///     println!("Solution: {:?}", solution.x);
31/// }
32/// ```
33pub fn direct_solve(a: &Array2<Complex64>, b: &Array1<Complex64>) -> DirectSolution {
34    use ndarray_linalg::Solve;
35
36    match a.solve(b) {
37        Ok(x) => DirectSolution { x, success: true },
38        Err(_) => DirectSolution {
39            x: Array1::zeros(b.len()),
40            success: false,
41        },
42    }
43}
44
45/// Solve Ax = b using LU factorization with pivoting
46///
47/// This is an alternative interface that's equivalent to direct_solve.
48pub fn direct_solve_lu(a: &Array2<Complex64>, b: &Array1<Complex64>) -> DirectSolution {
49    // Just delegate to the main solve function
50    direct_solve(a, b)
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest residual norm accepted by the checks below.
    const TOL: f64 = 1e-10;

    /// Shorthand for building a complex number from real and imaginary parts.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Euclidean norm of a complex vector.
    fn l2_norm(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_direct_solve_simple() {
        // Real-valued 2x2 system stored as complex numbers.
        let entries = vec![c(2.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(3.0, 0.0)];
        let a = Array2::from_shape_vec((2, 2), entries).unwrap();
        let b = Array1::from_vec(vec![c(4.0, 0.0), c(5.0, 0.0)]);

        let solution = direct_solve(&a, &b);
        assert!(solution.success);

        // The residual A x - b must vanish up to rounding.
        let residual = a.dot(&solution.x) - &b;
        assert!(l2_norm(&residual) < TOL);
    }

    #[test]
    fn test_direct_solve_complex() {
        // 2x2 system with genuinely complex coefficients.
        let entries = vec![c(1.0, 1.0), c(0.0, 1.0), c(1.0, 0.0), c(1.0, -1.0)];
        let a = Array2::from_shape_vec((2, 2), entries).unwrap();
        let b = Array1::from_vec(vec![c(2.0, 1.0), c(1.0, 0.0)]);

        let solution = direct_solve(&a, &b);
        assert!(solution.success);

        // The residual A x - b must vanish up to rounding.
        let residual = a.dot(&solution.x) - &b;
        assert!(l2_norm(&residual) < TOL);
    }

    #[test]
    fn test_direct_solve_identity() {
        let n = 5;
        let a: Array2<Complex64> = Array2::eye(n);
        let b = Array1::from_iter((1..=n).map(|i| c(i as f64, 0.0)));

        let solution = direct_solve(&a, &b);
        assert!(solution.success);

        // With the identity matrix the solution is the right-hand side itself.
        let difference = &solution.x - &b;
        assert!(l2_norm(&difference) < TOL);
    }
}