diffsol 0.1.9

A library for solving ordinary differential equations (ODEs) in Rust.
Documentation
use std::rc::Rc;

use anyhow::Result;
use sundials_sys::{
    realtype, SUNLinSolFree, SUNLinSolSetup, SUNLinSolSolve, SUNLinSol_Dense, SUNLinearSolver,
};

use crate::{
    ode_solver::sundials::sundials_check,
    op::linearise::LinearisedOp,
    vector::sundials::{get_suncontext, SundialsVector},
    LinearOp, Matrix, NonLinearOp, Op, SolverProblem, SundialsMatrix,
};

use super::LinearSolver;

pub struct SundialsLinearSolver<Op>
where
    Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
    linear_solver: Option<SUNLinearSolver>,
    problem: Option<SolverProblem<LinearisedOp<Op>>>,
    is_setup: bool,
    matrix: Option<SundialsMatrix>,
}

impl<Op> Default for SundialsLinearSolver<Op>
where
    Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
    fn default() -> Self {
        Self::new_dense()
    }
}

impl<Op> SundialsLinearSolver<Op>
where
    Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
    pub fn new_dense() -> Self {
        Self {
            linear_solver: None,
            problem: None,
            is_setup: false,
            matrix: None,
        }
    }
}

impl<Op> Drop for SundialsLinearSolver<Op>
where
    Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
    fn drop(&mut self) {
        if let Some(linear_solver) = self.linear_solver {
            unsafe { SUNLinSolFree(linear_solver) };
        }
    }
}

impl<Op> LinearSolver<Op> for SundialsLinearSolver<Op>
where
    Op: NonLinearOp<M = SundialsMatrix, V = SundialsVector, T = realtype>,
{
    fn set_problem(&mut self, problem: &SolverProblem<Op>) {
        let linearised_problem = problem.linearise();
        let matrix = SundialsMatrix::zeros(
            linearised_problem.f.nstates(),
            linearised_problem.f.nstates(),
        );
        let y0 = SundialsVector::new_serial(linearised_problem.f.nstates());
        let ctx = *get_suncontext();
        let linear_solver =
            unsafe { SUNLinSol_Dense(y0.sundials_vector(), matrix.sundials_matrix(), ctx) };
        self.matrix = Some(matrix);
        self.problem = Some(linearised_problem);
        self.linear_solver = Some(linear_solver);
    }

    fn set_linearisation(&mut self, x: &Op::V, t: Op::T) {
        Rc::<LinearisedOp<Op>>::get_mut(&mut self.problem.as_mut().expect("Problem not set").f)
            .unwrap()
            .set_x(x);
        let matrix = self.matrix.as_mut().expect("Matrix not set");
        let linear_solver = self.linear_solver.expect("Linear solver not set");
        self.problem.as_ref().unwrap().f.matrix_inplace(t, matrix);
        sundials_check(unsafe { SUNLinSolSetup(linear_solver, matrix.sundials_matrix()) }).unwrap();
        self.is_setup = true;
    }

    fn solve_in_place(&self, b: &mut Op::V) -> Result<()> {
        if !self.is_setup {
            return Err(anyhow::anyhow!("Linear solver not setup"));
        }
        let linear_solver = self.linear_solver.expect("Linear solver not set");
        let matrix = self.matrix.as_ref().expect("Matrix not set");
        let tol = 1e-6;
        sundials_check(unsafe {
            SUNLinSolSolve(
                linear_solver,
                matrix.sundials_matrix(),
                b.sundials_vector(),
                b.sundials_vector(),
                tol,
            )
        })
    }
}