use crate::StrError;
use russell_lab::Vector;
use russell_sparse::{CooMatrix, Sym};
use std::sync::Arc;
/// Type to use for the callbacks' auxiliary (`args`) parameter when no extra data is needed.
///
/// It is simply a `u8`; callers typically pass `&mut 0` and ignore it inside the callbacks.
pub type NoArgs = u8;
/// A system of ordinary differential equations, together with optional Jacobian
/// and mass-matrix callbacks.
///
/// `A` is the type of the auxiliary argument handed (mutably) to every callback,
/// and `'a` bounds whatever the callbacks may borrow.
pub struct System<'a, A> {
    /// Dimension of the system.
    pub(crate) ndim: usize,

    /// Number of nonzeros associated with the Jacobian matrix.
    pub(crate) jac_nnz: usize,

    /// Number of nonzeros associated with the mass matrix.
    pub(crate) mass_nnz: usize,

    /// Symmetry type most recently given to either matrix setter.
    pub(crate) symmetric: Sym,

    /// Symmetry type recorded when the Jacobian callback is set.
    sym_jac: Option<Sym>,

    /// Symmetry type recorded when the mass-matrix callback is set.
    sym_mass: Option<Sym>,

    /// Right-hand-side callback: `(f, x, y, args)` fills the output vector `f`.
    pub(crate) function: Arc<
        dyn Fn(&mut Vector, f64, &Vector, &mut A) -> Result<(), StrError> + Send + Sync + 'a,
    >,

    /// Optional analytical Jacobian callback: `(jj, alpha, x, y, args)`.
    pub(crate) jacobian: Option<
        Arc<
            dyn Fn(&mut CooMatrix, f64, f64, &Vector, &mut A) -> Result<(), StrError>
                + Send
                + Sync
                + 'a,
        >,
    >,

    /// Optional callback that fills the mass matrix.
    pub(crate) calc_mass: Option<Arc<dyn Fn(&mut CooMatrix) + Send + Sync + 'a>>,
}
impl<'a, A> System<'a, A> {
    /// Allocates a new system with the given dimension and right-hand-side callback.
    ///
    /// # Input
    ///
    /// * `ndim` -- dimension of the system
    /// * `function` -- callback `(f, x, y, args)` that fills the output vector `f`
    ///
    /// No Jacobian or mass-matrix callback is set. The Jacobian nonzero count starts at
    /// `ndim * ndim`, the mass nonzero count at `0`, and the symmetry at [`Sym::No`].
    pub fn new(
        ndim: usize,
        function: impl Fn(&mut Vector, f64, &Vector, &mut A) -> Result<(), StrError> + Send + Sync + 'a,
    ) -> Self {
        System {
            ndim,
            function: Arc::new(function),
            jacobian: None,
            calc_mass: None,
            jac_nnz: ndim * ndim,
            mass_nnz: 0,
            sym_jac: None,
            sym_mass: None,
            symmetric: Sym::No,
        }
    }

    /// Sets the analytical Jacobian callback.
    ///
    /// # Input
    ///
    /// * `nnz` -- number of nonzeros of the Jacobian. `None` selects `ndim * ndim`, or
    ///   `ndim * (ndim + 1) / 2` when `symmetric.triangular()` is true.
    /// * `symmetric` -- symmetry type of the Jacobian. It must equal the mass matrix's, if set.
    /// * `callback` -- `(jj, alpha, x, y, args)` filling the Jacobian matrix `jj`
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if a mass matrix was previously set
    /// with a different symmetry type.
    pub fn set_jacobian(
        &mut self,
        nnz: Option<usize>,
        symmetric: Sym,
        callback: impl Fn(&mut CooMatrix, f64, f64, &Vector, &mut A) -> Result<(), StrError> + Send + Sync + 'a,
    ) -> Result<(), StrError> {
        if let Some(sym) = self.sym_mass {
            if symmetric != sym {
                return Err("the Jacobian matrix must have the same symmetric type as the mass matrix");
            }
        }
        self.jac_nnz = nnz.unwrap_or_else(|| Self::default_nnz(self.ndim, symmetric));
        self.sym_jac = Some(symmetric);
        self.symmetric = symmetric;
        self.jacobian = Some(Arc::new(callback));
        Ok(())
    }

    /// Sets the callback that fills the mass matrix.
    ///
    /// # Input
    ///
    /// * `nnz` -- number of nonzeros of the mass matrix. `None` selects `ndim * ndim`, or
    ///   `ndim * (ndim + 1) / 2` when `symmetric.triangular()` is true.
    /// * `symmetric` -- symmetry type of the mass matrix. It must equal the Jacobian's, if set.
    /// * `callback` -- fills the given mass matrix
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if a Jacobian was previously set
    /// with a different symmetry type.
    pub fn set_mass(
        &mut self,
        nnz: Option<usize>,
        symmetric: Sym,
        callback: impl Fn(&mut CooMatrix) + Send + Sync + 'a,
    ) -> Result<(), StrError> {
        if let Some(sym) = self.sym_jac {
            if symmetric != sym {
                return Err("the mass matrix must have the same symmetric type as the Jacobian matrix");
            }
        }
        self.mass_nnz = nnz.unwrap_or_else(|| Self::default_nnz(self.ndim, symmetric));
        self.sym_mass = Some(symmetric);
        self.symmetric = symmetric;
        self.calc_mass = Some(Arc::new(callback));
        Ok(())
    }

    /// Returns the dimension of the system.
    pub fn get_ndim(&self) -> usize {
        self.ndim
    }

    /// Returns the number of nonzeros associated with the Jacobian matrix.
    pub fn get_jac_nnz(&self) -> usize {
        self.jac_nnz
    }

    /// Returns the number of nonzeros associated with the mass matrix (`0` until one is set).
    pub fn get_mass_nnz(&self) -> usize {
        self.mass_nnz
    }

    /// Default nonzero count for an `ndim × ndim` matrix with the given symmetry.
    ///
    /// The result is `ndim * (ndim + 1) / 2` when `symmetric.triangular()` is true,
    /// otherwise all `ndim * ndim` entries.
    fn default_nnz(ndim: usize, symmetric: Sym) -> usize {
        if symmetric.triangular() {
            (ndim + ndim * ndim) / 2
        } else {
            ndim * ndim
        }
    }
}

/// Cloning shares the callbacks (only the `Arc` reference counts are bumped), so no
/// `A: Clone` bound is required.
impl<A> Clone for System<'_, A> {
    fn clone(&self) -> Self {
        System {
            ndim: self.ndim,
            function: Arc::clone(&self.function),
            jacobian: self.jacobian.clone(),
            calc_mass: self.calc_mass.clone(),
            jac_nnz: self.jac_nnz,
            mass_nnz: self.mass_nnz,
            sym_jac: self.sym_jac,
            sym_mass: self.sym_mass,
            symmetric: self.symmetric,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::System;
    use crate::NoArgs;
    use russell_lab::Vector;
    use russell_sparse::{CooMatrix, Sym};

    #[test]
    fn ode_system_handles_errors() {
        let mut system = System::new(1, |f, _, _, _: &mut NoArgs| {
            f[0] = 1.0;
            Ok(())
        });
        let mut f = Vector::new(1);
        let x = 0.0;
        let y = Vector::new(1);
        let mut args = 0;
        (system.function)(&mut f, x, &y, &mut args).unwrap();
        assert_eq!(f[0], 1.0);

        // invoke the callbacks directly so that their bodies are exercised
        let jac_cb = |_: &mut CooMatrix, _: f64, _: f64, _: &Vector, _: &mut NoArgs| Ok(());
        let mas_cb = |_: &mut CooMatrix| ();
        let mut jj = CooMatrix::new(1, 1, 1, Sym::YesLower).unwrap();
        let mut mm = CooMatrix::new(1, 1, 1, Sym::YesLower).unwrap();
        (jac_cb)(&mut jj, 0.0, 0.0, &y, &mut 0).unwrap();
        (mas_cb)(&mut mm);

        // the mass matrix must agree with the Jacobian's symmetry
        system.set_jacobian(None, Sym::YesLower, jac_cb).unwrap();
        assert_eq!(
            system.set_mass(None, Sym::YesUpper, mas_cb).err(),
            Some("the mass matrix must have the same symmetric type as the Jacobian matrix")
        );
        // a rejected call must not modify the system
        assert!(system.calc_mass.is_none());
        assert_eq!(system.get_mass_nnz(), 0);

        // forget the Jacobian's symmetry so the reverse check can be exercised
        system.sym_jac = None;
        system.set_mass(None, Sym::YesLower, mas_cb).unwrap();
        assert_eq!(
            system.set_jacobian(None, Sym::YesUpper, jac_cb).err(),
            Some("the Jacobian matrix must have the same symmetric type as the mass matrix")
        );
        system.set_jacobian(None, Sym::YesLower, jac_cb).unwrap();
    }

    #[test]
    fn ode_system_works() {
        struct Args {
            n_function_eval: usize,
            more_data_goes_here: bool,
        }
        let mut args = Args {
            n_function_eval: 0,
            more_data_goes_here: false,
        };
        let system = System::new(2, |f, x, y, args: &mut Args| {
            args.n_function_eval += 1;
            f[0] = -x * y[1];
            f[1] = x * y[0];
            args.more_data_goes_here = true;
            Ok(())
        });
        assert_eq!(system.get_ndim(), 2);
        assert_eq!(system.get_jac_nnz(), 4);
        assert_eq!(system.get_mass_nnz(), 0);

        let x = 0.0;
        let y = Vector::new(2);
        let mut k = Vector::new(2);
        (system.function)(&mut k, x, &y, &mut args).unwrap();
        assert!(system.jacobian.is_none());
        assert!(system.calc_mass.is_none());
        assert_eq!(args.n_function_eval, 1);
        assert!(args.more_data_goes_here);

        let clone = system.clone();
        assert_eq!(clone.ndim, 2);
        assert_eq!(clone.jac_nnz, 4);
        assert_eq!(clone.mass_nnz, 0);
        assert_eq!(clone.sym_jac, None);
        assert_eq!(clone.sym_mass, None);
        assert_eq!(clone.symmetric, Sym::No);
        assert!(clone.jacobian.is_none());
        assert!(clone.calc_mass.is_none());

        // the clone shares the same right-hand-side callback
        (clone.function)(&mut k, x, &y, &mut args).unwrap();
        assert_eq!(args.n_function_eval, 2);
    }

    #[test]
    fn ode_system_set_jacobian_works() {
        struct Args {
            n_function_eval: usize,
            n_jacobian_eval: usize,
            more_data_goes_here_fn: bool,
            more_data_goes_here_jj: bool,
        }
        let mut args = Args {
            n_function_eval: 0,
            n_jacobian_eval: 0,
            more_data_goes_here_fn: false,
            more_data_goes_here_jj: false,
        };
        let mut system = System::new(2, |f, x, y, args: &mut Args| {
            args.n_function_eval += 1;
            f[0] = -x * y[1];
            f[1] = x * y[0];
            args.more_data_goes_here_fn = true;
            Ok(())
        });
        let symmetric = Sym::No;
        system
            .set_jacobian(Some(2), symmetric, |jj, alpha, x, _y, args: &mut Args| {
                args.n_jacobian_eval += 1;
                jj.reset();
                jj.put(0, 1, alpha * (-x)).unwrap();
                jj.put(1, 0, alpha * (x)).unwrap();
                args.more_data_goes_here_jj = true;
                Ok(())
            })
            .unwrap();
        assert_eq!(system.get_jac_nnz(), 2);
        assert_eq!(system.symmetric, Sym::No);

        let x = 0.0;
        let y = Vector::new(2);
        let mut k = Vector::new(2);
        (system.function)(&mut k, x, &y, &mut args).unwrap();
        let mut jj = CooMatrix::new(2, 2, 2, Sym::No).unwrap();
        let alpha = 1.0;
        (system.jacobian.as_ref().unwrap())(&mut jj, alpha, x, &y, &mut args).unwrap();
        assert_eq!(args.n_function_eval, 1);
        assert_eq!(args.n_jacobian_eval, 1);
        assert!(args.more_data_goes_here_fn);
        assert!(args.more_data_goes_here_jj);

        let clone = system.clone();
        assert!(clone.jacobian.is_some());
        assert_eq!(clone.jac_nnz, 2);
        assert_eq!(clone.sym_jac, Some(Sym::No));
    }

    #[test]
    fn ode_system_set_jacobian_default_nnz_works() {
        let mut system = System::new(3, |_, _, _, _: &mut NoArgs| Ok(()));
        let jac_cb = |_: &mut CooMatrix, _: f64, _: f64, _: &Vector, _: &mut NoArgs| Ok(());
        assert_eq!(system.get_jac_nnz(), 9);
        system.set_jacobian(None, Sym::YesUpper, jac_cb).unwrap();
        assert_eq!(system.get_jac_nnz(), 6);
        assert_eq!(system.symmetric, Sym::YesUpper);
        system.set_jacobian(None, Sym::No, jac_cb).unwrap();
        assert_eq!(system.get_jac_nnz(), 9);
        system.set_jacobian(Some(5), Sym::No, jac_cb).unwrap();
        assert_eq!(system.get_jac_nnz(), 5);
    }

    #[test]
    fn ode_system_set_mass_works() {
        let mut system = System::new(2, |f, _, _, _: &mut NoArgs| {
            f[0] = 1.0;
            f[1] = 1.0;
            Ok(())
        });
        let mut f = Vector::new(2);
        let x = 0.0;
        let y = Vector::new(2);
        let mut args = 0;
        (system.function)(&mut f, x, &y, &mut args).unwrap();
        let mas_cb = |_: &mut CooMatrix| ();
        let mut mm = CooMatrix::new(2, 2, 4, Sym::YesLower).unwrap();
        (mas_cb)(&mut mm);
        system.set_mass(None, Sym::YesLower, mas_cb).unwrap();
        assert_eq!(system.get_mass_nnz(), 3);
        assert_eq!(system.symmetric, Sym::YesLower);
        assert!(system.calc_mass.is_some());
        system.set_mass(None, Sym::No, mas_cb).unwrap();
        assert_eq!(system.get_mass_nnz(), 4);
        assert_eq!(system.symmetric, Sym::No);
    }
}