use crate::StrError;
use russell_lab::Vector;
use russell_sparse::{CooMatrix, Sym};
use std::marker::PhantomData;
/// Placeholder type for the auxiliary argument `A` of a [`System`] when the callbacks need no extra data.
///
/// Pass a mutable `u8` (e.g. `&mut 0`) where the callbacks expect `&mut A`.
pub type NoArgs = u8;
/// Definition of an ODE system: its dimension, its function callback, and optional Jacobian and mass matrix data.
///
/// # Type parameters
///
/// * `'a` — lifetime the boxed Jacobian callback may borrow for
/// * `F` — type of the function callback; it receives the output vector, a scalar `x`,
///   the input vector `y`, and the auxiliary argument
/// * `A` — type of the user-supplied auxiliary argument shared by all callbacks
pub struct System<
    'a,
    F: Fn(&mut Vector, f64, &Vector, &mut A) -> Result<(), StrError>,
    A,
> {
    /// Number of components of the system (size of the vectors and square matrices).
    pub(crate) ndim: usize,

    /// Callback evaluating the system; writes its result into the first argument.
    pub(crate) function: F,

    /// Optional Jacobian callback; `None` until a Jacobian is registered.
    pub(crate) jacobian: Option<Box<dyn Fn(&mut CooMatrix, f64, f64, &Vector, &mut A) -> Result<(), StrError> + 'a>>,

    /// Number of non-zero entries reserved for the Jacobian matrix.
    pub(crate) jac_nnz: usize,

    /// Symmetry flag used for the Jacobian (and for the mass matrix).
    pub(crate) jac_sym: Sym,

    /// Optional mass matrix in coordinate (COO) format; `None` until initialized.
    pub(crate) mass_matrix: Option<CooMatrix>,

    /// Type marker for `A`; no `A` value is stored in the struct.
    phantom: PhantomData<fn() -> A>,
}
impl<'a, F, A> System<'a, F, A>
where
    F: Fn(&mut Vector, f64, &Vector, &mut A) -> Result<(), StrError>,
{
    /// Creates a new system with `ndim` components and the given function callback.
    ///
    /// The Jacobian is initially disabled (`None`), the Jacobian non-zero estimate defaults
    /// to the dense value `ndim * ndim`, the symmetry flag to [`Sym::No`], and no mass
    /// matrix is allocated.
    ///
    /// # Arguments
    ///
    /// * `ndim` — number of components of the system
    /// * `function` — callback `(f, x, y, args)` that writes its result into `f`
    pub fn new(ndim: usize, function: F) -> Self {
        System {
            ndim,
            function,
            jacobian: None,
            jac_nnz: ndim * ndim,
            jac_sym: Sym::No,
            mass_matrix: None,
            phantom: PhantomData,
        }
    }

    /// Registers the Jacobian callback.
    ///
    /// # Arguments
    ///
    /// * `jac_nnz` — number of non-zero entries to reserve; `None` selects the dense
    ///   value `ndim * ndim`
    /// * `jac_sym` — symmetry flag for the Jacobian; it is also used by
    ///   [`System::init_mass_matrix`]
    /// * `callback` — function `(jj, alpha, x, y, args)` that fills the COO matrix `jj`
    ///   (the crate's tests scale each entry by `alpha`)
    pub fn set_jacobian(
        &mut self,
        jac_nnz: Option<usize>,
        jac_sym: Sym,
        callback: impl Fn(&mut CooMatrix, f64, f64, &Vector, &mut A) -> Result<(), StrError> + 'a,
    ) {
        self.jac_nnz = jac_nnz.unwrap_or(self.ndim * self.ndim);
        self.jac_sym = jac_sym;
        self.jacobian = Some(Box::new(callback));
    }

    /// Allocates an `ndim × ndim` mass matrix with room for `max_nnz` non-zero entries.
    ///
    /// The matrix reuses the Jacobian's symmetry flag, so the Jacobian must be set first.
    /// Any previously initialized mass matrix is replaced.
    ///
    /// # Errors
    ///
    /// * if the Jacobian has not been set via [`System::set_jacobian`]
    /// * if `CooMatrix::new` rejects the requested dimensions or `max_nnz` (for example,
    ///   presumably a zero `max_nnz`); in that case the existing mass matrix is left untouched
    pub fn init_mass_matrix(&mut self, max_nnz: usize) -> Result<(), StrError> {
        if self.jacobian.is_none() {
            return Err("the Jacobian function must be enabled first");
        }
        // Propagate allocation errors instead of panicking: this function already reports
        // failures through its `Result`.
        let mass = CooMatrix::new(self.ndim, self.ndim, max_nnz, self.jac_sym)?;
        self.mass_matrix = Some(mass);
        Ok(())
    }

    /// Stores `value` at position `(i, j)` of the mass matrix.
    ///
    /// # Errors
    ///
    /// * if the mass matrix has not been initialized via [`System::init_mass_matrix`]
    /// * any error reported by `CooMatrix::put` (e.g. index out of range or capacity exceeded)
    pub fn mass_put(&mut self, i: usize, j: usize, value: f64) -> Result<(), StrError> {
        match self.mass_matrix.as_mut() {
            Some(mass) => mass.put(i, j, value),
            None => Err("mass matrix has not been initialized/enabled"),
        }
    }

    /// Returns the number of components of the system.
    pub fn get_ndim(&self) -> usize {
        self.ndim
    }

    /// Returns the number of non-zero entries reserved for the Jacobian matrix.
    pub fn get_jac_nnz(&self) -> usize {
        self.jac_nnz
    }
}
#[cfg(test)]
mod tests {
    use super::System;
    use crate::NoArgs;
    use russell_lab::Vector;
    use russell_sparse::{CooMatrix, Sym};

    #[test]
    fn ode_system_works() {
        // auxiliary data shared with the callback
        struct Args {
            n_function_eval: usize,
            more_data_goes_here: bool,
        }
        let mut args = Args {
            n_function_eval: 0,
            more_data_goes_here: false,
        };
        let system = System::new(2, |f, x, y, args: &mut Args| {
            args.n_function_eval += 1;
            f[0] = -x * y[1];
            f[1] = x * y[0];
            args.more_data_goes_here = true;
            Ok(())
        });

        // without a Jacobian, the nnz estimate is the dense ndim * ndim value
        assert_eq!(system.get_ndim(), 2);
        assert_eq!(system.get_jac_nnz(), 4);
        assert!(system.mass_matrix.is_none());

        let x = 0.0;
        let y = Vector::new(2);
        let mut k = Vector::new(2);
        (system.function)(&mut k, x, &y, &mut args).unwrap();
        assert!(system.jacobian.is_none());
        println!("n_function_eval = {}", args.n_function_eval);
        assert_eq!(args.n_function_eval, 1);
        assert!(args.more_data_goes_here);
    }

    #[test]
    fn ode_system_set_jacobian_works() {
        // auxiliary data shared by both callbacks
        struct Args {
            n_function_eval: usize,
            n_jacobian_eval: usize,
            more_data_goes_here_fn: bool,
            more_data_goes_here_jj: bool,
        }
        let mut args = Args {
            n_function_eval: 0,
            n_jacobian_eval: 0,
            more_data_goes_here_fn: false,
            more_data_goes_here_jj: false,
        };
        let mut system = System::new(2, |f, x, y, args: &mut Args| {
            args.n_function_eval += 1;
            f[0] = -x * y[1];
            f[1] = x * y[0];
            args.more_data_goes_here_fn = true;
            Ok(())
        });
        system.set_jacobian(Some(2), Sym::No, |jj, alpha, x, _y, args: &mut Args| {
            args.n_jacobian_eval += 1;
            jj.reset();
            jj.put(0, 1, alpha * (-x)).unwrap();
            jj.put(1, 0, alpha * (x)).unwrap();
            args.more_data_goes_here_jj = true;
            Ok(())
        });

        // the explicit nnz overrides the dense default
        assert_eq!(system.get_jac_nnz(), 2);

        system.init_mass_matrix(2).unwrap();
        assert!(system.mass_matrix.is_some());
        system.mass_put(0, 0, 1.0).unwrap();
        system.mass_put(1, 1, 1.0).unwrap();

        let x = 0.0;
        let y = Vector::new(2);
        let mut k = Vector::new(2);
        (system.function)(&mut k, x, &y, &mut args).unwrap();
        let mut jj = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        let alpha = 1.0;
        (system.jacobian.as_ref().unwrap())(&mut jj, alpha, x, &y, &mut args).unwrap();
        println!("n_function_eval = {}", args.n_function_eval);
        println!("n_jacobian_eval = {}", args.n_jacobian_eval);
        assert_eq!(args.n_function_eval, 1);
        assert_eq!(args.n_jacobian_eval, 1);
        assert!(args.more_data_goes_here_fn);
        assert!(args.more_data_goes_here_jj);
    }

    #[test]
    fn ode_system_handles_errors() {
        let mut system = System::new(1, |f, _, _, _: &mut NoArgs| {
            f[0] = 1.0;
            Ok(())
        });
        let mut f = Vector::new(1);
        let x = 0.0;
        let y = Vector::new(1);
        let mut args: NoArgs = 0;
        (system.function)(&mut f, x, &y, &mut args).unwrap();
        assert_eq!(
            system.mass_put(0, 0, 1.0).err(),
            Some("mass matrix has not been initialized/enabled")
        );
        assert_eq!(
            system.init_mass_matrix(1).err(),
            Some("the Jacobian function must be enabled first")
        );
    }
}