use crate::error::Error;
use crate::linalg::Matrix;
use crate::methods::irk::radau::Radau5;
use crate::status::Status;
use crate::traits::{Real, State};
impl<E, T: Real, Y: State<T>> Radau5<E, T, Y> {
pub fn initialize(&mut self, t0: T, tf: T, y0: &Y) -> Result<(), Error<T, Y>> {
let n = y0.len();
if self.uround <= T::from_f64(1e-19).unwrap() || self.uround >= T::one() {
let e = Error::BadInput {
msg: "UROUND is out of range (expected ~1e-16).".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
let ten = T::from_f64(10.0).unwrap();
if self.atol[0] <= T::zero() || self.rtol[0] <= ten * self.uround {
let e = Error::BadInput {
msg: "Tolerances are too small (require ATOL > 0 and RTOL > 10*UROUND)."
.to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
for i in 0..n {
let quot = self.atol[i] / self.rtol[i];
let expm = T::from_f64(2.0 / 3.0).unwrap();
self.rtol[i] = T::from_f64(0.1).unwrap() * self.rtol[i].powf(expm);
self.atol[i] = self.rtol[i] * quot;
}
if self.max_steps == 0 {
let e = Error::BadInput {
msg: "max_steps must be > 0".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
if self.max_newton_iter == 0 {
let e = Error::BadInput {
msg: "max_newton_iter must be > 0".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
if !(self.safety_factor > T::from_f64(0.001).unwrap() && self.safety_factor < T::one()) {
let e = Error::BadInput {
msg: "safety_factor must be in (0.001, 1.0)".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
if self.thet >= T::one() {
let e = Error::BadInput {
msg: "thet must be < 1.0".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
let tolst = self.rtol[0]; if self.newton_tol <= T::zero() {
let upper = T::from_f64(0.03).unwrap().min(tolst.sqrt());
let lower = T::from_f64(10.0).unwrap() * self.uround / tolst;
self.newton_tol = lower.max(upper);
} else {
let min_allowed = self.uround / tolst;
if self.newton_tol <= min_allowed {
let e = Error::BadInput {
msg: "newton_tol too small (<= UROUND/RTOL')".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
}
if self.quot1 > T::one() || self.quot2 < T::one() {
let e = Error::BadInput {
msg: "Invalid (quot1, quot2): require quot1 <= 1 and quot2 >= 1".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
if !self.h_max.is_finite() || self.h_max <= T::zero() {
self.h_max = (tf - t0).abs();
}
if self.min_scale <= T::zero() || self.min_scale > T::one() {
let e = Error::BadInput {
msg: "min_scale must be in (0, 1] (default 0.2)".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
if self.max_scale < T::one() {
let e = Error::BadInput {
msg: "max_scale must be >= 1 (default 8.0)".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
self.facl = T::one() / self.min_scale;
self.facr = T::one() / self.max_scale;
if self.facl < T::one() || self.facr > T::one() {
let e = Error::BadInput {
msg: "Invalid clamp factors derived from scales (facl>=1, facr<=1)".to_string(),
};
self.status = Status::Error(e.clone());
return Err(e);
}
self.cfac = self.safety_factor
* (T::one() + T::from_f64(2.0).unwrap() * T::from_usize(self.max_newton_iter).unwrap());
self.steps = 0;
self.rejects = 0;
self.n_accepted = 0;
self.jacobian_age = 0;
self.t = t0;
self.tf = tf;
self.y = *y0;
self.mass = Matrix::identity(n);
self.dydt = Y::zeros();
self.t_prev = self.t;
self.y_prev = self.y;
self.dydt_prev = self.dydt;
self.h_prev = self.h;
self.hhfac = self.h;
for i in 0..n {
self.scal
.set(i, self.atol[i] + self.rtol[i] * self.y.get(i).abs());
}
self.z = [Y::zeros(), Y::zeros(), Y::zeros()];
self.k = [Y::zeros(), Y::zeros(), Y::zeros()];
self.f = [Y::zeros(), Y::zeros(), Y::zeros()];
self.jacobian = Matrix::zeros(n, n);
self.e1 = Matrix::zeros(n, n);
self.e2r = Matrix::zeros(n, n);
self.e2i = Matrix::zeros(n, n);
self.ip1 = vec![0; n];
self.ip2 = vec![0; n];
self.a = Matrix::zeros(2 * n, 2 * n);
self.b = vec![T::zero(); 2 * n];
self.cont = [Y::zeros(); 4];
self.first = true;
self.reject = false;
self.status = Status::Initialized;
Ok(())
}
}