// neco-eigensolve 0.1.0
//
// Lightweight solvers for generalized eigenvalue problems.
// Documentation
#[cfg(feature = "faer-lu")]
use faer::c64 as FaerC64;
#[cfg(feature = "faer-lu")]
use faer::col::Col;
#[cfg(feature = "faer-lu")]
use faer::linalg::solvers::Solve;
#[cfg(feature = "faer-lu")]
use faer::sparse::linalg::solvers::{Lu, SymbolicLu};
#[cfg(feature = "faer-lu")]
use faer::sparse::{SparseColMat, Triplet};

use neco_sparse::CsrMat;

use crate::c64::C64;

use super::backend::PreparedLinearSolver;
use super::internal_lu::ShiftedCsrData;

/// Converts a crate-local [`C64`] into faer's complex scalar by copying its
/// real and imaginary parts.
#[cfg(feature = "faer-lu")]
fn to_faer_c64(v: C64) -> FaerC64 {
    let (re, im) = (v.re, v.im);
    FaerC64::new(re, im)
}

/// Converts faer's complex scalar back into the crate-local [`C64`] type,
/// component by component.
#[cfg(feature = "faer-lu")]
fn from_faer_c64(v: FaerC64) -> C64 {
    let (re, im) = (v.re, v.im);
    C64::new(re, im)
}

/// Assembles the complex shifted matrix for `k_mat`, `m_mat` and shift `z` as a
/// faer column-compressed sparse matrix for the sparse LU factorization.
///
/// The entry values come from `ShiftedCsrData::from_shift`. NOTE(review): this is
/// presumably `z * M - K`, which matches the diagonal unit test, but confirm in `internal_lu`.
/// Off-diagonal entries whose squared magnitude is at most `DROP_TOL_ABS2` are
/// omitted as structural zeros. Diagonal entries are always kept.
///
/// # Errors
///
/// Returns an error if:
/// - `diagonal_positions` rejects the shifted pattern;
/// - any stored entry is non-finite (NaN or infinite);
/// - faer fails to build the matrix from the collected triplets.
#[cfg(feature = "faer-lu")]
fn build_complex_shifted_matrix(
    k_mat: &CsrMat<f64>,
    m_mat: &CsrMat<f64>,
    z: C64,
) -> Result<SparseColMat<usize, FaerC64>, String> {
    // Squared-magnitude cutoff below which off-diagonal entries are dropped.
    const DROP_TOL_ABS2: f64 = 1e-40;

    let shifted = ShiftedCsrData::from_shift(k_mat, m_mat, z);
    // Validation only: the returned positions are not needed here.
    shifted
        .diagonal_positions()
        .map_err(|err| format!("direct LU shifted matrix is invalid: {err}"))?;

    let mut triplets = Vec::with_capacity(shifted.values.len());
    for row_idx in 0..shifted.n {
        let start = shifted.row_offsets[row_idx];
        let end = shifted.row_offsets[row_idx + 1];
        let cols = &shifted.col_indices[start..end];
        let vals = &shifted.values[start..end];
        for (&col, &val) in cols.iter().zip(vals) {
            // A NaN fails the `abs2 > tol` comparison below and would be silently
            // dropped, so the LU would factor a different matrix. Reject explicitly.
            if !(val.re.is_finite() && val.im.is_finite()) {
                return Err(format!(
                    "direct LU shifted matrix has non-finite entry at ({row_idx}, {col})"
                ));
            }
            let abs2 = val.re * val.re + val.im * val.im;
            // Keep the diagonal unconditionally; only off-diagonals are pruned.
            if abs2 > DROP_TOL_ABS2 || col == row_idx {
                triplets.push(Triplet::new(row_idx, col, to_faer_c64(val)));
            }
        }
    }

    SparseColMat::try_new_from_triplets(shifted.n, shifted.n, &triplets)
        .map_err(|e| format!("failed to build direct LU shift matrix: {e}"))
}

/// Prepared shifted-system solver backed by a faer sparse LU factorization.
///
/// Constructed by [`DirectLuShiftedSolver::new`] for one complex shift `z`. Every
/// `solve_block` call reuses the stored factorization.
#[cfg(feature = "faer-lu")]
pub(crate) struct DirectLuShiftedSolver {
    // Numeric LU factors of the matrix assembled by `build_complex_shifted_matrix`.
    lu: Lu<usize, FaerC64>,
}

#[cfg(feature = "faer-lu")]
impl DirectLuShiftedSolver {
    /// Assembles the complex shifted matrix for `k_mat`, `m_mat` and shift `z`.
    /// It then runs faer's symbolic analysis followed by the numeric sparse LU
    /// factorization.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` if any of these steps fails: matrix assembly,
    /// symbolic factorization, or numeric factorization.
    pub(crate) fn new(k_mat: &CsrMat<f64>, m_mat: &CsrMat<f64>, z: C64) -> Result<Self, String> {
        let matrix = build_complex_shifted_matrix(k_mat, m_mat, z)?;

        let symbolic = SymbolicLu::try_new(matrix.as_ref().symbolic())
            .map_err(|e| format!("direct LU symbolic factorization failed: {e}"))?;

        Lu::try_new_with_symbolic(symbolic, matrix.as_ref())
            .map(|lu| Self { lu })
            .map_err(|e| format!("direct LU numeric factorization failed: {e}"))
    }
}

#[cfg(feature = "faer-lu")]
impl PreparedLinearSolver for DirectLuShiftedSolver {
    /// Solves the factored shifted system for `m0` real right-hand sides.
    ///
    /// `rhs` packs the right-hand sides column after column, `n` entries each.
    /// Each column is promoted to complex with a zero imaginary part and solved
    /// with the stored LU factors. The solutions come back in the same packed layout.
    /// `_tol` is unused because the solve is direct.
    ///
    /// # Errors
    ///
    /// Returns an error when `rhs.len()` is not `n * m0`.
    fn solve_block(&self, rhs: &[f64], n: usize, m0: usize, _tol: f64) -> Result<Vec<C64>, String> {
        let expected_len = n * m0;
        if rhs.len() != expected_len {
            return Err(format!(
                "rhs length {} does not match n * m0 = {}",
                rhs.len(),
                expected_len
            ));
        }

        // A single scratch column, refilled for each right-hand side.
        let mut work = Col::<FaerC64>::zeros(n);
        let mut solutions = Vec::with_capacity(expected_len);
        for block in 0..m0 {
            let column = &rhs[block * n..(block + 1) * n];
            for (row, &value) in column.iter().enumerate() {
                work[row] = FaerC64::new(value, 0.0);
            }
            let solved = self.lu.solve(&work);
            solutions.extend((0..n).map(|row| from_faer_c64(solved[row])));
        }
        Ok(solutions)
    }
}

#[cfg(all(test, feature = "faer-lu"))]
mod tests {
    use super::*;

    /// Builds an `n x n` CSR matrix that stores `diag` on its main diagonal and nothing else.
    fn diagonal_csr(n: usize, diag: &[f64]) -> CsrMat<f64> {
        // Each row owns exactly one stored entry, located in its own column.
        let row_offsets = (0..=n).collect::<Vec<usize>>();
        let col_indices = (0..n).collect::<Vec<usize>>();
        CsrMat::try_from_csr_data(n, n, row_offsets, col_indices, diag.to_vec()).unwrap()
    }

    /// Checks a diagonal system with K = diag(1, 2, 3), M = I and z = 2.5 + 1i.
    /// A unit right-hand side must give x_i = 1 / (z - k_i).
    #[test]
    fn direct_lu_diagonal_system() {
        let k_mat = diagonal_csr(3, &[1.0, 2.0, 3.0]);
        let m_mat = diagonal_csr(3, &[1.0, 1.0, 1.0]);
        let shift = C64::new(2.5, 1.0);
        let solver = DirectLuShiftedSolver::new(&k_mat, &m_mat, shift).unwrap();

        let ones = vec![1.0f64; 3];
        let solution = solver.solve_block(&ones, 3, 1, 1e-10).unwrap();

        let one = C64::new(1.0, 0.0);
        let expected = [
            one / C64::new(1.5, 1.0),
            one / C64::new(0.5, 1.0),
            one / C64::new(-0.5, 1.0),
        ];
        for (i, &want) in expected.iter().enumerate() {
            let diff = solution[i] - want;
            let err = (diff.re * diff.re + diff.im * diff.im).sqrt();
            assert!(err < 1e-10, "x[{i}] err={err}");
        }
    }
}