use self::BoundaryCondition::Dirichlet;
use self::ExMethod::{Euler, RK4};
use self::ODEOptions::{BoundCond, InitCond, Method, StepSize, StopCond, Times};
use self::ImMethod::{BDF1, GL4};
use operation::extra_ops::Real;
use operation::mut_ops::MutFP;
use std::collections::HashMap;
use structure::dual::Dual;
use structure::matrix::{Matrix, FP, LinearAlgebra};
use structure::vector::FPVector;
use numerical::utils::jacobian_real;
use util::non_macro::{cat, concat, zeros, eye};
use util::print::Printable;
use {VecOps, VecWithDual, Dualist};
#[cfg(feature = "oxidize")]
use {blas_daxpy, blas_daxpy_return};
/// Explicit integration schemes available to [`ExplicitODE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum ExMethod {
    /// Forward Euler: one right-hand-side evaluation per step.
    Euler,
    /// Classical fourth-order Runge–Kutta: four right-hand-side evaluations per step.
    RK4,
}
/// Implicit integration schemes available to [`ImplicitODE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum ImMethod {
    /// Backward Euler (BDF1). Selecting it currently panics: `ImplicitODE` leaves it unimplemented.
    BDF1,
    /// Two-stage Gauss–Legendre implicit Runge–Kutta, solved with Newton iteration.
    GL4,
}
/// Kind of boundary condition attached to a boundary `State`.
///
/// The solvers in this module store boundary conditions but do not yet use them while integrating.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum BoundaryCondition {
    /// Conventionally prescribes the solution value on the boundary.
    Dirichlet,
    /// Conventionally prescribes the derivative on the boundary.
    Neumann,
}
/// Keys of the "has this option been set?" map kept by each solver.
///
/// Each variant is flagged by the matching `set_*` builder method of [`ODE`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum ODEOptions {
    /// Flagged by `set_initial_condition`.
    InitCond,
    /// Flagged by `set_boundary_condition`.
    BoundCond,
    /// Flagged by `set_method`.
    Method,
    /// Flagged by `set_stop_condition`.
    StopCond,
    /// Flagged by `set_step_size`.
    StepSize,
    /// Flagged by `set_times`.
    Times,
}
/// Snapshot of an ODE system: the independent variable, the state vector and its derivative.
#[derive(Debug, Clone, Default)]
pub struct State<T>
where
    T: Real,
{
    /// Independent variable (e.g. time) at which `value` is taken.
    pub param: T,
    /// Current state vector `y`.
    pub value: Vec<T>,
    /// Derivative of `value`, expected to be written by the user-supplied updater.
    pub deriv: Vec<T>,
}
impl<T: Real> State<T> {
    /// Returns a copy of this state with every component converted to `f64`.
    pub fn to_f64(&self) -> State<f64> {
        let convert = |items: &Vec<T>| -> Vec<f64> {
            items.iter().cloned().map(|item| item.to_f64()).collect()
        };
        State {
            param: self.param.to_f64(),
            value: convert(&self.value),
            deriv: convert(&self.deriv),
        }
    }

    /// Returns a copy of this state with every component converted to `Dual`.
    pub fn to_dual(&self) -> State<Dual> {
        let convert = |items: &Vec<T>| -> Vec<Dual> {
            items.iter().cloned().map(|item| item.to_dual()).collect()
        };
        State {
            param: self.param.to_dual(),
            value: convert(&self.value),
            deriv: convert(&self.deriv),
        }
    }

    /// Builds a state from its parameter, state vector (`value`) and derivative.
    pub fn new(param: T, state: Vec<T>, deriv: Vec<T>) -> Self {
        let value = state;
        State {
            param,
            value,
            deriv,
        }
    }
}
/// Right-hand side for [`ExplicitODE`]: reads `param`/`value` and is expected to write `deriv`.
pub type ExUpdater = for<'a> fn(&'a mut State<f64>);
/// Right-hand side for [`ImplicitODE`]: same contract as `ExUpdater`, over dual numbers.
pub type ImUpdater = for<'a> fn(&'a mut State<Dual>);
/// Builder-style interface shared by the ODE solvers in this module.
pub trait ODE {
    /// Output of `integrate` (a matrix of recorded rows in the provided solvers).
    type Records;
    /// Type of the independent variable.
    type Param;
    /// Enumeration of integration schemes the solver accepts.
    type ODEMethod;

    /// Advances the internal state by a single step.
    fn mut_update(&mut self);

    /// Runs the configured number of steps and returns the recorded trajectory.
    fn integrate(&mut self) -> Self::Records;

    /// Sets the starting state.
    fn set_initial_condition<T>(&mut self, init: State<T>) -> &mut Self
    where
        T: Real;

    /// Stores a pair of boundary states with their condition kinds.
    fn set_boundary_condition<T>(
        &mut self,
        bound1: (State<T>, BoundaryCondition),
        bound2: (State<T>, BoundaryCondition),
    ) -> &mut Self
    where
        T: Real;

    /// Sets the step size of each update.
    fn set_step_size(&mut self, dt: f64) -> &mut Self;

    /// Selects the integration scheme.
    fn set_method(&mut self, method: Self::ODEMethod) -> &mut Self;

    /// Installs a predicate that ends integration early once it returns `true`.
    fn set_stop_condition(&mut self, f: fn(&Self) -> bool) -> &mut Self;

    /// Sets how many steps `integrate` performs.
    fn set_times(&mut self, n: usize) -> &mut Self;

    /// Reports whether enough options are configured to run `integrate`.
    fn check_enough(&self) -> bool;
}
/// Explicit (Euler / RK4) solver over `f64` states.
#[derive(Clone)]
pub struct ExplicitODE {
    /// Current state; advanced in place by `mut_update`.
    state: State<f64>,
    /// User right-hand side; expected to fill `deriv` from `param`/`value`.
    func: fn(&mut State<f64>),
    /// Step size used by every update.
    step_size: f64,
    /// Scheme used by `mut_update`.
    method: ExMethod,
    /// Copy of the initial condition as it was set.
    init_cond: State<f64>,
    /// First stored boundary condition (not read by the integrator).
    bound_cond1: (State<f64>, BoundaryCondition),
    /// Second stored boundary condition (not read by the integrator).
    bound_cond2: (State<f64>, BoundaryCondition),
    /// Early-stop predicate, consulted only when `StopCond` has been set.
    stop_cond: fn(&Self) -> bool,
    /// Number of steps performed by `integrate`.
    times: usize,
    /// Which options have been explicitly set.
    options: HashMap<ODEOptions, bool>,
}
impl ExplicitODE {
    /// Creates a solver for the right-hand side `f` with every option flagged as unset.
    ///
    /// The placeholder values (Euler, zero step size, zero steps, never-stop predicate)
    /// exist only so the struct is complete; the caller still configures the solver
    /// through the `ODE` setters.
    pub fn new(f: ExUpdater) -> Self {
        let unset: HashMap<ODEOptions, bool> = [InitCond, StepSize, BoundCond, Method, StopCond, Times]
            .iter()
            .map(|&option| (option, false))
            .collect();

        ExplicitODE {
            func: f,
            state: Default::default(),
            init_cond: Default::default(),
            bound_cond1: (Default::default(), Dirichlet),
            bound_cond2: (Default::default(), Dirichlet),
            method: Euler,
            step_size: 0.0,
            times: 0,
            stop_cond: |_| false,
            options: unset,
        }
    }

    /// Borrows the current state of the solver.
    pub fn get_state(&self) -> &State<f64> {
        &self.state
    }
}
impl ODE for ExplicitODE {
type Records = Matrix;
type Param = f64;
type ODEMethod = ExMethod;
fn mut_update(&mut self) {
match self.method {
Euler => {
(self.func)(&mut self.state);
let dt = self.step_size;
match () {
#[cfg(feature = "oxidize")]
() => {
blas_daxpy(dt, &self.state.deriv, &mut self.state.value);
}
_ => {
self.state
.value
.mut_zip_with(|x, y| x + y * dt, &self.state.deriv);
}
}
self.state.param += dt;
}
RK4 => {
let h = self.step_size;
let h2 = h / 2f64;
let yn = self.state.value.clone();
(self.func)(&mut self.state);
let k1 = self.state.deriv.clone();
let k1_add = k1.s_mul(h2);
self.state.param += h2;
self.state.value.mut_zip_with(|x, y| x + y, &k1_add);
(self.func)(&mut self.state);
let k2 = self.state.deriv.clone();
let k2_add = k2.zip_with(|x, y| h2 * x - y, &k1_add);
self.state.value.mut_zip_with(|x, y| x + y, &k2_add);
(self.func)(&mut self.state);
let k3 = self.state.deriv.clone();
let k3_add = k3.zip_with(|x, y| h * x - y, &k2_add);
self.state.param += h2;
self.state.value.mut_zip_with(|x, y| x + y, &k3_add);
(self.func)(&mut self.state);
let k4 = self.state.deriv.clone();
for i in 0..k1.len() {
self.state.value[i] =
yn[i] + (k1[i] + 2f64 * k2[i] + 2f64 * k3[i] + k4[i]) * h / 6f64;
}
}
}
}
fn integrate(&mut self) -> Self::Records {
assert!(self.check_enough(), "Not enough fields!");
let mut result = zeros(self.times + 1, self.state.value.len() + 1);
result.subs_row(0, &cat(self.state.param, self.state.value.clone()));
match self.options.get(&StopCond) {
Some(stop) if *stop => {
let mut key = 1usize;
for i in 1..self.times + 1 {
self.mut_update();
result.subs_row(i, &cat(self.state.param, self.state.value.clone()));
key += 1;
if (self.stop_cond)(&self) {
println!("Reach the stop condition!");
print!("Current values are: ");
cat(self.state.param, self.state.value.clone()).print();
break;
}
}
return result.take_row(key);
}
_ => {
for i in 1..self.times + 1 {
self.mut_update();
result.subs_row(i, &cat(self.state.param, self.state.value.clone()));
}
return result;
}
}
}
fn set_initial_condition<T: Real>(&mut self, init: State<T>) -> &mut Self {
if let Some(x) = self.options.get_mut(&InitCond) {
*x = true
}
self.init_cond = init.to_f64();
self.state = init.to_f64();
self
}
fn set_boundary_condition<T: Real>(
&mut self,
bound1: (State<T>, BoundaryCondition),
bound2: (State<T>, BoundaryCondition),
) -> &mut Self {
if let Some(x) = self.options.get_mut(&BoundCond) {
*x = true
}
self.bound_cond1 = (bound1.0.to_f64(), bound1.1);
self.bound_cond2 = (bound2.0.to_f64(), bound2.1);
self
}
fn set_step_size(&mut self, dt: f64) -> &mut Self {
if let Some(x) = self.options.get_mut(&StepSize) {
*x = true
}
self.step_size = dt;
self
}
fn set_method(&mut self, method: Self::ODEMethod) -> &mut Self {
if let Some(x) = self.options.get_mut(&Method) {
*x = true
}
self.method = method;
self
}
fn set_stop_condition(&mut self, f: fn(&Self) -> bool) -> &mut Self {
if let Some(x) = self.options.get_mut(&StopCond) {
*x = true
}
self.stop_cond = f;
self
}
fn set_times(&mut self, n: usize) -> &mut Self {
if let Some(x) = self.options.get_mut(&Times) {
*x = true
}
self.times = n;
self
}
fn check_enough(&self) -> bool {
match self.options.get(&Method) {
Some(x) => {
if !*x {
return false;
}
}
None => {
return false;
}
}
match self.options.get(&StepSize) {
Some(x) => {
if !*x {
return false;
}
}
None => {
return false;
}
}
match self.options.get(&InitCond) {
None => {
return false;
}
Some(x) => {
if !*x {
match self.options.get(&BoundCond) {
None => {
return false;
}
Some(_) => (),
}
}
}
}
match self.options.get(&Times) {
None => {
return false;
}
Some(x) => {
if !*x {
return false;
}
}
}
true
}
}
/// Implicit solver whose state is held as dual numbers so the Newton step can differentiate it.
#[derive(Clone)]
pub struct ImplicitODE {
    /// Current state; advanced in place by `mut_update`.
    state: State<Dual>,
    /// User right-hand side; expected to fill `deriv` from `param`/`value`.
    func: fn(&mut State<Dual>),
    /// Step size used by every update.
    step_size: f64,
    /// Newton stopping threshold, compared against the norm of the Newton update.
    rtol: f64,
    /// Scheme used by `mut_update`.
    method: ImMethod,
    /// Copy of the initial condition (as `f64`) as it was set.
    init_cond: State<f64>,
    /// First stored boundary condition (not read by the integrator).
    bound_cond1: (State<f64>, BoundaryCondition),
    /// Second stored boundary condition (not read by the integrator).
    bound_cond2: (State<f64>, BoundaryCondition),
    /// Early-stop predicate, consulted only when `StopCond` has been set.
    stop_cond: fn(&Self) -> bool,
    /// Number of steps performed by `integrate`.
    times: usize,
    /// Which options have been explicitly set.
    options: HashMap<ODEOptions, bool>,
}
impl ImplicitODE {
    /// Creates a solver for the right-hand side `f` with every option flagged as unset.
    ///
    /// The placeholders are GL4, zero step size, zero steps, `rtol = 1e-6` and a
    /// never-stop predicate. The caller configures the solver through the `ODE` setters.
    pub fn new(f: ImUpdater) -> Self {
        let unset: HashMap<ODEOptions, bool> = [InitCond, StepSize, BoundCond, Method, StopCond, Times]
            .iter()
            .map(|&option| (option, false))
            .collect();

        ImplicitODE {
            func: f,
            state: Default::default(),
            init_cond: Default::default(),
            bound_cond1: (Default::default(), Dirichlet),
            bound_cond2: (Default::default(), Dirichlet),
            method: GL4,
            step_size: 0.0,
            rtol: 1e-6,
            times: 0,
            stop_cond: |_| false,
            options: unset,
        }
    }

    /// Borrows the current (dual-valued) state of the solver.
    pub fn get_state(&self) -> &State<Dual> {
        &self.state
    }

    /// Sets the threshold at which the GL4 Newton iteration stops.
    ///
    /// Unlike the `ODE` setters, this does not flag any entry in the options map.
    pub fn set_rtol(&mut self, rtol: f64) -> &mut Self {
        self.rtol = rtol;
        self
    }
}
/// `sqrt(3)` at full `f64` precision.
const SQRT3: f64 = 1.7320508075688772;
/// `sqrt(3) / 6`, the offset that appears throughout the Gauss–Legendre tableau.
const SQRT3_OVER_6: f64 = SQRT3 / 6.0;
/// Two-stage Gauss–Legendre Butcher tableau, one row per stage: `[c_i, a_i1, a_i2]`.
///
/// The weights `b_1 = b_2 = 1/2` are applied where the step is combined.
const GL4_TAB: [[f64; 3]; 2] = [
    [0.5 - SQRT3_OVER_6, 0.25, 0.25 - SQRT3_OVER_6],
    [0.5 + SQRT3_OVER_6, 0.25 + SQRT3_OVER_6, 0.25],
];
#[allow(non_snake_case)]
impl ODE for ImplicitODE {
type Records = Matrix;
type Param = Dual;
type ODEMethod = ImMethod;
fn mut_update(&mut self) {
match self.method {
BDF1 => {
unimplemented!()
}
GL4 => {
let f = |t: Dual, y: Vec<Dual>| {
let mut st = State::new(t, y.clone(), y);
(self.func)(&mut st);
st.deriv
};
let h = self.step_size;
let t = self.state.param;
let t1: Dual = t + GL4_TAB[0][0] * h;
let t2: Dual = t + GL4_TAB[1][0] * h;
let yn = &self.state.value;
let n = yn.len();
let k1_init: Vec<f64> = f(t, yn.clone()).values();
let k2_init: Vec<f64> = f(t, yn.clone()).values();
let mut k_curr: Vec<f64> = concat(k1_init.clone(), k2_init.clone());
let mut err = 1f64;
let g = |k: &Vec<Dual>| -> Vec<Dual> {
let k1 = k.take(n);
let k2 = k.skip(n);
concat(
f(
t1,
yn.add(
&k1.s_mul(GL4_TAB[0][1] * h)
.add(&k2.s_mul(GL4_TAB[0][2]*h)),
),
),
f(
t2,
yn.add(
&k1.s_mul(GL4_TAB[1][1] * h)
.add(&k2.s_mul(GL4_TAB[1][2] * h)),
),
),
)
};
let I = eye(2 * n);
let mut Dg = jacobian_real(Box::new(g), &k_curr);
let mut DG = &I - &Dg;
let mut DG_inv = DG.inv().unwrap();
let mut G = k_curr.sub(&g(&k_curr.conv_dual()).values());
let mut num_iter: usize = 0;
while err >= self.rtol && num_iter <= 10 {
let DGG = &DG_inv * &G;
let k_prev = k_curr.clone();
k_curr.mut_zip_with(|x, y| x - y, &DGG.col(0));
Dg = jacobian_real(Box::new(g), &k_curr);
DG = &I - &Dg;
DG_inv = DG.inv().unwrap();
G = k_curr.sub(&g(&k_curr.conv_dual()).values());
err = k_curr.sub(&k_prev).norm();
num_iter += 1;
}
let (k1, k2) = (k_curr.take(n), k_curr.skip(n));
(self.func)(&mut self.state);
let mut y_curr = self.state.value.values();
y_curr = y_curr.add(&k1.s_mul(0.5 * h).add(&k2.s_mul(0.5 * h)));
self.state.value = y_curr.conv_dual();
self.state.param = self.state.param + h;
}
}
}
fn integrate(&mut self) -> Self::Records {
assert!(self.check_enough(), "Not enough fields!");
let mut result = zeros(self.times + 1, self.state.value.len() + 1);
result.subs_row(0, &cat(self.state.param.to_f64(), self.state.value.values()));
match self.options.get(&StopCond) {
Some(stop) if *stop => {
let mut key = 1usize;
for i in 1..self.times + 1 {
self.mut_update();
result.subs_row(i, &cat(self.state.param.to_f64(), self.state.value.values()));
key += 1;
if (self.stop_cond)(&self) {
println!("Reach the stop condition!");
print!("Current values are: ");
cat(self.state.param.to_f64(), self.state.value.values()).print();
break;
}
}
return result.take_row(key);
}
_ => {
for i in 1..self.times + 1 {
self.mut_update();
result.subs_row(i, &cat(self.state.param.to_f64(), self.state.value.values()));
}
return result;
}
}
}
fn set_initial_condition<T: Real>(&mut self, init: State<T>) -> &mut Self {
if let Some(x) = self.options.get_mut(&InitCond) {
*x = true
}
self.init_cond = init.to_f64();
self.state = init.to_dual();
self
}
fn set_boundary_condition<T: Real>(
&mut self,
bound1: (State<T>, BoundaryCondition),
bound2: (State<T>, BoundaryCondition),
) -> &mut Self {
if let Some(x) = self.options.get_mut(&BoundCond) {
*x = true
}
self.bound_cond1 = (bound1.0.to_f64(), bound1.1);
self.bound_cond2 = (bound2.0.to_f64(), bound2.1);
self
}
fn set_step_size(&mut self, dt: f64) -> &mut Self {
if let Some(x) = self.options.get_mut(&StepSize) {
*x = true
}
self.step_size = dt;
self
}
fn set_method(&mut self, method: Self::ODEMethod) -> &mut Self {
if let Some(x) = self.options.get_mut(&Method) {
*x = true
}
self.method = method;
self
}
fn set_stop_condition(&mut self, f: fn(&Self) -> bool) -> &mut Self {
if let Some(x) = self.options.get_mut(&StopCond) {
*x = true
}
self.stop_cond = f;
self
}
fn set_times(&mut self, n: usize) -> &mut Self {
if let Some(x) = self.options.get_mut(&Times) {
*x = true
}
self.times = n;
self
}
fn check_enough(&self) -> bool {
match self.options.get(&Method) {
Some(x) => {
if !*x {
return false;
}
}
None => {
return false;
}
}
match self.options.get(&StepSize) {
Some(x) => {
if !*x {
return false;
}
}
None => {
return false;
}
}
match self.options.get(&InitCond) {
None => {
return false;
}
Some(x) => {
if !*x {
match self.options.get(&BoundCond) {
None => {
return false;
}
Some(_) => (),
}
}
}
}
match self.options.get(&Times) {
None => {
return false;
}
Some(x) => {
if !*x {
return false;
}
}
}
true
}
}