use crate::{
linalg::Matrix,
traits::{Real, State},
};
/// Solves the linear system `A·x = b` in place, given an LU factorization of `A`.
///
/// This is the substitution half of a Gaussian-elimination solver. `a` and `ip`
/// are expected to hold the output of the matching decomposition routine
/// (NOTE(review): presumably a port of Hairer's `DEC` — confirm against the
/// caller). The layout this routine relies on is:
///
/// * on and above the diagonal of `a`: the upper-triangular factor `U`;
/// * strictly below the diagonal: the elimination multipliers stored *negated*
///   (they are added, not subtracted, during forward elimination);
/// * `ip[k]`: the row exchanged with row `k` at elimination step `k`, for
///   `k in 0..n - 1`. Any entry `ip[n - 1]` is not read.
///
/// On return `b` holds the solution `x`. An empty system (`n == 0`) is a no-op.
///
/// # Panics
///
/// Panics if `n > 0` and `ip.len() < n - 1`, or (in debug builds) if
/// `b.len() != a.nrows()`.
pub fn lin_solve<T: Real, Y: State<T>>(a: &Matrix<T>, b: &mut Y, ip: &[usize]) {
    let n = a.nrows();
    debug_assert_eq!(b.len(), n, "RHS length must match matrix size");
    // Nothing to solve; this also keeps `n - 1` below from underflowing.
    if n == 0 {
        return;
    }
    let nm1 = n - 1;

    // Forward elimination: replay each recorded row swap on `b`, then clear
    // column `k` below the diagonal using the stored (negated) multipliers.
    for (k, &m) in ip[..nm1].iter().enumerate() {
        let t = b.get(m);
        let bk = b.get(k);
        b.set(m, bk);
        b.set(k, t);
        for i in k + 1..n {
            let bi = b.get(i) + a[(i, k)] * t;
            b.set(i, bi);
        }
    }

    // Back substitution with `U`, from the last row up to row 1; each solved
    // component is eliminated from the rows above it.
    for k in (1..n).rev() {
        let xk = b.get(k) / a[(k, k)];
        b.set(k, xk);
        let t = -xk;
        for i in 0..k {
            let bi = b.get(i) + a[(i, k)] * t;
            b.set(i, bi);
        }
    }
    // Row 0 has no rows above it left to update, so only the division remains.
    let x0 = b.get(0) / a[(0, 0)];
    b.set(0, x0);
}
/// Solves the complex linear system `(Ar + i·Ai)·x = br + i·bi` in place, given
/// an LU factorization whose real and imaginary parts are stored in `ar` and `ai`.
///
/// This is the complex counterpart of [`lin_solve`]. Applied to the complex
/// matrix `ar + i·ai`, it expects this layout:
///
/// * `U` on and above the diagonal;
/// * the *negated* elimination multipliers strictly below the diagonal (they
///   are added during forward elimination);
/// * `ip[k]` naming the row exchanged with row `k` at step `k`, for
///   `k in 0..n - 1`.
///
/// NOTE(review): presumably produced by a port of Hairer's `DECC` — confirm
/// against the caller.
///
/// On return `br` and `bi` hold the real and imaginary parts of `x`.
/// An empty system (`n == 0`) is a no-op.
///
/// # Panics
///
/// Panics if `n > 0` and `ip.len() < n - 1`, or (in debug builds) if either
/// right-hand-side length differs from `ar.nrows()`.
pub fn lin_solve_complex<T: Real, Y: State<T>>(
    ar: &Matrix<T>,
    ai: &Matrix<T>,
    br: &mut Y,
    bi: &mut Y,
    ip: &[usize],
) {
    let n = ar.nrows();
    debug_assert_eq!(br.len(), n, "RHS length must match matrix size");
    debug_assert_eq!(bi.len(), n, "RHS length must match matrix size");
    // Nothing to solve; this also keeps `n - 1` below from underflowing.
    if n == 0 {
        return;
    }
    let nm1 = n - 1;

    // Complex quotient (nr + i·ni) / (dr + i·di), computed by multiplying by
    // the conjugate of the denominator. The operation order is kept fixed so
    // every division site rounds identically.
    let cdiv = |nr: T, ni: T, dr: T, di: T| -> (T, T) {
        let den = dr * dr + di * di;
        ((nr * dr + ni * di) / den, (ni * dr - nr * di) / den)
    };

    // Forward elimination: replay each recorded row swap on (br, bi), then
    // add the (negated) complex multiplier times the pivot entry to each
    // row below.
    for (k, &m) in ip[..nm1].iter().enumerate() {
        let tr = br.get(m);
        let ti = bi.get(m);
        let brk = br.get(k);
        let bik = bi.get(k);
        br.set(m, brk);
        bi.set(m, bik);
        br.set(k, tr);
        bi.set(k, ti);
        for i in k + 1..n {
            let prod_r = ar[(i, k)] * tr - ai[(i, k)] * ti;
            let prod_i = ai[(i, k)] * tr + ar[(i, k)] * ti;
            let bir = br.get(i) + prod_r;
            let bii = bi.get(i) + prod_i;
            br.set(i, bir);
            bi.set(i, bii);
        }
    }

    // Back substitution with the complex `U`, last row up to row 1.
    for k in (1..n).rev() {
        let (xr, xi) = cdiv(br.get(k), bi.get(k), ar[(k, k)], ai[(k, k)]);
        br.set(k, xr);
        bi.set(k, xi);
        let tr = -xr;
        let ti = -xi;
        for i in 0..k {
            let prod_r = ar[(i, k)] * tr - ai[(i, k)] * ti;
            let prod_i = ai[(i, k)] * tr + ar[(i, k)] * ti;
            let bir = br.get(i) + prod_r;
            let bii = bi.get(i) + prod_i;
            br.set(i, bir);
            bi.set(i, bii);
        }
    }
    // Row 0 has no rows above it left to update, so only the division remains.
    let (x0r, x0i) = cdiv(br.get(0), bi.get(0), ar[(0, 0)], ai[(0, 0)]);
    br.set(0, x0r);
    bi.set(0, x0i);
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::SMatrix;

    #[test]
    fn test_sol_simple() {
        let mut a = Matrix::zeros(1, 1);
        a[(0, 0)] = 2.0;
        let mut b = SMatrix::<f64, 1, 1>::from_element(4.0);
        let ip = [0];
        lin_solve(&a, &mut b, &ip);
        assert!((b[0] - 2.0_f64).abs() < 1e-10);
    }

    /// Exercises the row swap, forward elimination and back substitution,
    /// none of which run for a 1x1 system.
    #[test]
    fn test_sol_pivoted_2x2() {
        // A = [[2, 1], [4, 3]], b = [3, 7]  =>  x = [1, 1].
        // Partial pivoting swaps rows 0 and 1: U = [[4, 3], [0, -0.5]], with
        // the negated multiplier -0.5 stored below the diagonal.
        let mut a = Matrix::zeros(2, 2);
        a[(0, 0)] = 4.0;
        a[(0, 1)] = 3.0;
        a[(1, 0)] = -0.5;
        a[(1, 1)] = -0.5;
        let ip = [1, 1];
        let mut b = SMatrix::<f64, 2, 1>::from_column_slice(&[3.0, 7.0]);
        lin_solve(&a, &mut b, &ip);
        assert!((b[0] - 1.0_f64).abs() < 1e-10);
        assert!((b[1] - 1.0_f64).abs() < 1e-10);
    }

    #[test]
    fn test_solc_simple() {
        // 2 / (1 + i) = 1 - i
        let mut ar = Matrix::zeros(1, 1);
        let mut ai = Matrix::zeros(1, 1);
        ar[(0, 0)] = 1.0;
        ai[(0, 0)] = 1.0;
        let mut br = SMatrix::<f64, 1, 1>::from_element(2.0);
        let mut bi = SMatrix::<f64, 1, 1>::from_element(0.0);
        let ip = [0];
        lin_solve_complex(&ar, &ai, &mut br, &mut bi, &ip);
        assert!((br[0] - 1.0_f64).abs() < 1e-10);
        assert!((bi[0] - (-1.0_f64)).abs() < 1e-10);
    }

    /// Exercises the complex multiply-add in both elimination sweeps.
    #[test]
    fn test_solc_2x2() {
        // L = [[1, 0], [i, 1]] (stored negated as -i), U = [[1+i, 1], [0, 2i]].
        // x = [1, i]  =>  U x = [1+2i, -2],  b = L U x = [1+2i, -4+i].
        let mut ar = Matrix::zeros(2, 2);
        let mut ai = Matrix::zeros(2, 2);
        ar[(0, 0)] = 1.0;
        ai[(0, 0)] = 1.0;
        ar[(0, 1)] = 1.0;
        ai[(1, 0)] = -1.0;
        ai[(1, 1)] = 2.0;
        let ip = [0, 1];
        let mut br = SMatrix::<f64, 2, 1>::from_column_slice(&[1.0, -4.0]);
        let mut bi = SMatrix::<f64, 2, 1>::from_column_slice(&[2.0, 1.0]);
        lin_solve_complex(&ar, &ai, &mut br, &mut bi, &ip);
        assert!((br[0] - 1.0_f64).abs() < 1e-10);
        assert!(bi[0].abs() < 1e-10);
        assert!(br[1].abs() < 1e-10);
        assert!((bi[1] - 1.0_f64).abs() < 1e-10);
    }
}