use crate::{
dae::{AlgebraicNumericalMethod, DAE},
error::Error,
linalg::{Matrix, lin_solve, lin_solve_complex, lu_decomp, lu_decomp_complex},
methods::{Algebraic, h_init::InitialStepSize, irk::radau::Radau5},
stats::Evals,
status::Status,
traits::{Real, State},
utils::{constrain_step_size, validate_step_size_parameters},
};
impl<T: Real, Y: State<T>> AlgebraicNumericalMethod<T, Y> for Radau5<Algebraic, T, Y> {
fn init<F>(&mut self, dae: &F, t0: T, tf: T, y0: &Y) -> Result<Evals, Error<T, Y>>
where
F: DAE<T, Y>,
{
let mut evals = Evals::new();
if self.h0 == T::zero() {
self.h0 = InitialStepSize::<Algebraic>::compute(
dae, t0, tf, y0, 5, &self.rtol, &self.atol, self.h_min, self.h_max, &mut evals,
);
}
match validate_step_size_parameters::<T, Y>(self.h0, self.h_min, self.h_max, t0, tf) {
Ok(h0) => {
self.h = (self.filter)(h0);
}
Err(status) => return Err(status),
}
self.initialize(t0, tf, y0)?;
let n = y0.len();
self.mass = Matrix::zeros(n, n);
dae.mass(&mut self.mass);
dae.diff(self.t, &self.y, &mut self.dydt);
evals.function += 1;
self.dydt_prev = self.dydt;
Ok(evals)
}
fn step<F>(&mut self, dae: &F) -> Result<Evals, Error<T, Y>>
where
F: DAE<T, Y>,
{
let mut evals = Evals::new();
let n = self.y.len();
if self.call_jac {
evals.jacobian += 1;
dae.jacobian(self.t, &self.y, &mut self.jacobian);
self.call_jac = false;
}
if self.call_decomp {
let fac1 = self.u1 / self.h;
let alphn = self.alph / self.h;
let betan = self.beta / self.h;
for j in 0..n {
for i in 0..n {
self.e1[(i, j)] = self.mass[(i, j)] * fac1 - self.jacobian[(i, j)];
}
}
if lu_decomp(&mut self.e1, &mut self.ip1).is_err() {
self.singular_count += 1;
if self.singular_count > 5 {
self.status = Status::Error(Error::LinearAlgebra {
msg: "Repeated singular matrix in step rejection; aborting.".to_string(),
});
return Err(Error::LinearAlgebra {
msg: "Repeated singular matrix in step rejection; aborting.".to_string(),
});
}
self.unexpected_step_rejection();
return Ok(evals);
}
for j in 0..n {
for i in 0..n {
let m = self.mass[(i, j)];
self.e2r[(i, j)] = m * alphn - self.jacobian[(i, j)];
self.e2i[(i, j)] = m * betan;
}
}
if lu_decomp_complex(&mut self.e2r, &mut self.e2i, &mut self.ip2).is_err() {
self.singular_count += 1;
if self.singular_count > 5 {
self.status = Status::Error(Error::LinearAlgebra {
msg: "Repeated singular matrix in step rejection; aborting.".to_string(),
});
return Err(Error::LinearAlgebra {
msg: "Repeated singular matrix in step rejection; aborting.".to_string(),
});
}
self.unexpected_step_rejection();
return Ok(evals);
}
evals.decompositions += 1;
}
self.steps += 1;
if self.steps >= self.max_steps {
self.status = Status::Error(Error::MaxSteps {
t: self.t,
y: self.y,
});
return Err(Error::MaxSteps {
t: self.t,
y: self.y,
});
}
if self.h.abs() < self.h_prev.abs() * self.uround {
self.status = Status::Error(Error::StepSize {
t: self.t,
y: self.y,
});
return Err(Error::StepSize {
t: self.t,
y: self.y,
});
}
for &i in &self.index2 {
let val = self.scal.get(i) / self.hhfac;
self.scal.set(i, val);
}
for &i in &self.index3 {
let val = self.scal.get(i) / (self.hhfac * self.hhfac);
self.scal.set(i, val);
}
if self.first {
for i in 0..3 {
self.z[i] = Y::zeros();
self.f[i] = Y::zeros();
}
} else {
let c3q = self.h / self.h_prev;
let c1q = self.c1 * c3q;
let c2q = self.c2 * c3q;
let ak1 = self.cont[1];
let ak2 = self.cont[2];
let ak3 = self.cont[3];
self.z[0] = (ak1 + (ak2 + ak3 * (c1q - self.c1m1)) * (c1q - self.c2m1)) * c1q;
self.z[1] = (ak1 + (ak2 + ak3 * (c2q - self.c1m1)) * (c2q - self.c2m1)) * c2q;
self.z[2] = (ak1 + (ak2 + ak3 * (c3q - self.c1m1)) * (c3q - self.c2m1)) * c3q;
self.f[0] = self.z[0] * self.tinv[(0, 0)]
+ self.z[1] * self.tinv[(0, 1)]
+ self.z[2] * self.tinv[(0, 2)];
self.f[1] = self.z[0] * self.tinv[(1, 0)]
+ self.z[1] * self.tinv[(1, 1)]
+ self.z[2] * self.tinv[(1, 2)];
self.f[2] = self.z[0] * self.tinv[(2, 0)]
+ self.z[1] * self.tinv[(2, 1)]
+ self.z[2] * self.tinv[(2, 2)];
}
self.faccon = self.faccon.max(self.uround).powf(T::from_f64(0.8).unwrap());
self.theta = self.thet.abs();
let mut newt_iter: usize = 0;
'newton: loop {
if newt_iter >= self.max_newton_iter {
self.unexpected_step_rejection();
return Ok(evals);
}
newt_iter += 1;
let t1 = self.t + self.c1 * self.h;
let t2 = self.t + self.c2 * self.h;
let t3 = self.t + self.h;
let y1 = self.y + self.z[0];
let y2 = self.y + self.z[1];
let y3 = self.y + self.z[2];
dae.diff(t1, &y1, &mut self.k[0]);
dae.diff(t2, &y2, &mut self.k[1]);
dae.diff(t3, &y3, &mut self.k[2]);
evals.function += 3;
self.z[0] = self.k[0] * self.tinv[(0, 0)]
+ self.k[1] * self.tinv[(0, 1)]
+ self.k[2] * self.tinv[(0, 2)];
self.z[1] = self.k[0] * self.tinv[(1, 0)]
+ self.k[1] * self.tinv[(1, 1)]
+ self.k[2] * self.tinv[(1, 2)];
self.z[2] = self.k[0] * self.tinv[(2, 0)]
+ self.k[1] * self.tinv[(2, 1)]
+ self.k[2] * self.tinv[(2, 2)];
let fac1 = self.u1 / self.h;
let alphn = self.alph / self.h;
let betan = self.beta / self.h;
for i in 0..n {
let mut s1 = T::zero();
let mut s2 = T::zero();
let mut s3 = T::zero();
for j in 0..n {
let mij = self.mass[(i, j)];
s1 -= mij * self.f[0].get(j);
s2 -= mij * self.f[1].get(j);
s3 -= mij * self.f[2].get(j);
}
self.z[0].set(i, self.z[0].get(i) + s1 * fac1);
self.z[1].set(i, self.z[1].get(i) + s2 * alphn - s3 * betan);
self.z[2].set(i, self.z[2].get(i) + s3 * alphn + s2 * betan);
}
lin_solve(&self.e1, &mut self.z[0], &self.ip1);
let (z12, z3) = self.z.split_at_mut(2);
let z2 = &mut z12[1];
let z3 = &mut z3[0];
lin_solve_complex(&self.e2r, &self.e2i, z2, z3, &self.ip2);
evals.solves += 2;
evals.newton += 1;
let mut dyno = T::zero();
for i in 0..n {
let sc = self.scal.get(i);
let v1 = self.z[0].get(i) / sc;
let v2 = self.z[1].get(i) / sc;
let v3 = self.z[2].get(i) / sc;
dyno = dyno + v1 * v1 + v2 * v2 + v3 * v3;
}
dyno = (dyno / T::from_f64((3 * n) as f64).unwrap()).sqrt();
if newt_iter > 1 && newt_iter < self.max_newton_iter {
let thq = dyno / self.dynold;
if newt_iter == 2 {
self.theta = thq;
} else {
self.theta = (thq * self.thqold).sqrt();
}
self.thqold = thq;
if self.theta < T::from_f64(0.99).unwrap() {
self.faccon = self.theta / (T::one() - self.theta);
let remaining_iters = (self.max_newton_iter - 1 - newt_iter) as f64;
let dyth =
self.faccon * dyno * self.theta.powf(T::from_f64(remaining_iters).unwrap())
/ self.newton_tol;
if dyth >= T::one() {
let qnewt = T::from_f64(1e-4)
.unwrap()
.max(T::from_f64(20.0).unwrap().min(dyth));
let exponent = -T::one() / T::from_f64(4.0 + remaining_iters).unwrap();
self.hhfac = T::from_f64(0.8).unwrap() * qnewt.powf(exponent);
self.h *= self.hhfac;
self.h = (self.filter)(self.h);
self.status = Status::RejectedStep;
self.reject = true;
return Ok(evals);
}
} else {
self.unexpected_step_rejection();
return Ok(evals);
}
}
self.dynold = dyno.max(self.uround);
self.f[0] += self.z[0];
self.f[1] += self.z[1];
self.f[2] += self.z[2];
self.z[0] = self.f[0] * self.tmat[(0, 0)]
+ self.f[1] * self.tmat[(0, 1)]
+ self.f[2] * self.tmat[(0, 2)];
self.z[1] = self.f[0] * self.tmat[(1, 0)]
+ self.f[1] * self.tmat[(1, 1)]
+ self.f[2] * self.tmat[(1, 2)];
self.z[2] = self.f[0] * self.tmat[(2, 0)] + self.f[1];
if self.faccon * dyno > self.newton_tol {
continue 'newton;
} else {
break 'newton;
}
}
let hee1 = self.dd1 / self.h;
let hee2 = self.dd2 / self.h;
let hee3 = self.dd3 / self.h;
let mut f1 = self.z[0] * hee1 + self.z[1] * hee2 + self.z[2] * hee3;
let mut f2 = Y::zeros();
let mut cont = Y::zeros();
for i in 0..n {
let mut sum = T::zero();
for j in 0..n {
sum += self.mass[(i, j)] * f1.get(j);
}
f2.set(i, sum);
cont.set(i, sum + self.dydt.get(i));
}
lin_solve(&self.e1, &mut cont, &self.ip1);
evals.solves += 1;
let mut err = T::zero();
for i in 0..n {
let r = cont.get(i) / self.scal.get(i);
err += r * r;
}
let mut err = (err / T::from_usize(n).unwrap())
.sqrt()
.max(T::from_f64(1e-10).unwrap());
if err >= T::one() && (self.first || self.reject) {
cont = self.y + cont;
dae.diff(self.t, &cont, &mut f1);
evals.function += 1;
cont = f1 + f2;
lin_solve(&self.e1, &mut cont, &self.ip1);
evals.solves += 1;
err = T::zero();
for i in 0..n {
let r = cont.get(i) / self.scal.get(i);
err += r * r;
}
err = (err / T::from_usize(n).unwrap())
.sqrt()
.max(T::from_f64(1e-10).unwrap());
}
let fac = self.safety_factor.min(
self.cfac
/ (T::from_usize(newt_iter).unwrap()
+ T::from_f64(2.0).unwrap() * T::from_usize(self.max_newton_iter).unwrap()),
);
let mut quot = self
.facr
.max(self.facl.min(err.powf(T::from_f64(0.25).unwrap()) / fac));
let mut hnew = self.h / quot;
if err < T::one() {
self.first = false;
self.n_accepted += 1;
if self.predictive {
if self.n_accepted > 1 {
let mut facgus = (self.h_acc / self.h)
* (err * err / self.err_acc).powf(T::from_f64(0.25).unwrap())
/ self.safety_factor;
facgus = self.facr.max(self.facl.min(facgus));
quot = quot.max(facgus);
hnew = self.h / quot;
}
self.h_acc = self.h;
self.err_acc = err.max(T::from_f64(1e-2).unwrap());
}
self.t_prev = self.t;
self.y_prev = self.y;
self.dydt_prev = self.dydt;
self.h_prev = self.h;
self.y += self.z[2];
self.t += self.h;
dae.diff(self.t, &self.y, &mut self.dydt);
evals.function += 1;
self.cont[0] = self.y;
self.cont[1] = (self.z[1] - self.z[2]) / self.c2m1;
let ak = (self.z[0] - self.z[1]) / self.c1mc2;
let acont3 = (ak - (self.z[0] / self.c1)) / self.c2;
self.cont[2] = (ak - self.cont[1]) / self.c1m1;
self.cont[3] = self.cont[2] - acont3;
for i in 0..n {
self.scal
.set(i, self.atol[i] + self.rtol[i] * self.y.get(i).abs());
}
hnew = constrain_step_size(hnew, self.h_min, self.h_max);
if self.reject {
let posneg = self.h.signum();
hnew = posneg * hnew.abs().min(self.h.abs());
self.reject = false;
self.status = Status::Solving;
}
hnew = (self.filter)(hnew);
let qt = hnew / self.h;
self.hhfac = self.h;
if self.theta < self.thet && qt > self.quot1 && qt < self.quot2 {
self.call_decomp = false;
self.call_jac = false;
return Ok(evals);
};
self.h = hnew;
self.hhfac = self.h;
if self.theta < self.thet {
self.call_jac = false;
return Ok(evals);
}
self.call_jac = true;
self.call_decomp = true;
} else {
self.reject = true;
self.status = Status::RejectedStep;
if self.first {
self.h *= T::from_f64(0.1).unwrap();
self.hhfac = T::from_f64(0.1).unwrap();
} else {
self.hhfac = hnew / self.h;
self.h = hnew;
}
self.h = constrain_step_size(self.h, self.h_min, self.h_max);
self.h = (self.filter)(self.h);
}
Ok(evals)
}
fn t(&self) -> T {
self.t
}
fn y(&self) -> &Y {
&self.y
}
fn t_prev(&self) -> T {
self.t_prev
}
fn y_prev(&self) -> &Y {
&self.y_prev
}
fn h(&self) -> T {
self.h
}
fn set_h(&mut self, h: T) {
self.h = (self.filter)(h);
}
fn status(&self) -> &Status<T, Y> {
&self.status
}
fn set_status(&mut self, status: Status<T, Y>) {
self.status = status;
}
}