use bulirsch::{self, Integrator};
use crate::Vec3;
use crate::errors::RocheError;
use crate::x_l1;
use pyo3::prelude::*;
/// Builds a starting position and velocity for a stream integration.
///
/// The returned position lies `SMALL` (1e-5) below `x_l1(q)` along the x-axis,
/// with a y-offset of `-m1 * SMALL`; the returned velocity is
/// `(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0)`. Both vectors have a zero
/// z-component.
///
/// NOTE(review): `lambda1` and `m1` look like the growth rate and direction of
/// the linearised flow around the point returned by `x_l1` — confirm against
/// the reference derivation.
///
/// # Errors
///
/// Propagates any error returned by `x_l1(q)`.
#[pyfunction]
pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
    // Size of the displacement away from the starting point.
    const SMALL: f64 = 1.0e-5;

    let x_start: f64 = x_l1(q)?;
    let mu: f64 = q / (1.0 + q);

    // Sum of the two inverse-cube terms, one per body (at x = 0 and x = 1).
    let near_term: f64 = (1.0 - mu) / x_start.powi(3);
    let far_term: f64 = mu / (1.0 - x_start).powi(3);
    let a: f64 = near_term + far_term;

    let root_term: f64 = (a * (9.0 * a - 8.0)).sqrt();
    let lambda1: f64 = (((a - 2.0) + root_term) / 2.0).sqrt();
    let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;

    let position = Vec3::new(x_start - SMALL, -m1 * SMALL, 0.0);
    let velocity = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
    Ok((position, velocity))
}
/// Integrates `r` and `v` through `OrbitalSystem { q }` until `r.length()`
/// reaches `rad`, then refines the final step by bisection.
///
/// The integration proceeds in two stages:
///
/// 1. Fixed steps of `min(1e-2, smax)` are taken until `r.length()` is no
///    longer strictly on the same side of `rad` as it was initially.
/// 2. The last step is bisected, each trial restarting from the state at the
///    beginning of that step, until the radii bracketing `rad` differ by no
///    more than `acc`.
///
/// On return `r` and `v` hold the state reached by the last integration that
/// was carried out.
///
/// # Parameters
///
/// * `q` - value stored in `OrbitalSystem::q` and forwarded to `rocacc`.
/// * `r`, `v` - position and velocity; updated in place.
/// * `rad` - target value of `r.length()`.
/// * `acc` - tolerance on the difference between the bracketing radii.
/// * `smax` - upper bound on the step handed to the integrator.
///
/// # Returns
///
/// The integration time elapsed between the initial state and the state left
/// in `r` and `v`.
///
/// # Panics
///
/// Panics if the accumulated time exceeds `TMAX` (10.0) before `rad` is
/// crossed, or if the integrator reports an error for a step.
pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
    const TMAX: f64 = 10.0;
    const T_NEXT: f64 = 1.0e-2;

    let system = OrbitalSystem { q };
    let mut integrator = Integrator::default()
        .with_abs_tol(1.0e-8)
        .with_rel_tol(1.0e-8)
        .into_adaptive();

    let rinit: f64 = r.length();
    let mut rnow: f64 = rinit;
    let mut time: f64 = 0.0;

    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
    let mut y_next = ndarray::Array::zeros(y.raw_dim());

    // State, radius and time at the beginning of the most recent step; the
    // bisection stage restarts every trial from here.
    let mut yo = y.clone();
    let mut rlo: f64 = rinit;
    let mut step_start: f64 = 0.0;

    let mut delta_t: f64 = T_NEXT.min(smax);

    // Stage 1: step until the radius stops being strictly on its initial side of `rad`.
    while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
        yo.assign(&y);
        rlo = rnow;
        step_start = time;

        integrator
            .step(&system, delta_t, y.view(), y_next.view_mut())
            .unwrap();
        y.assign(&y_next);
        r.set(y[0], y[1], y[2]);
        v.set(y[3], y[4], y[5]);
        rnow = r.length();
        time += delta_t;
        if time > TMAX {
            panic!("roche::stradv taken too long without crossing given radius.")
        }
    }

    // Stage 2: bisect the step length over [lo, hi], measured from `step_start`.
    let mut lo: f64 = 0.0;
    let mut hi: f64 = delta_t;
    let mut rhi: f64 = rnow;
    while (rhi - rlo).abs() > acc {
        delta_t = (lo + hi) / 2.0;
        // The interval can no longer be split in floating point; without this
        // check a too-small `acc` would loop forever.
        if delta_t == lo || delta_t == hi {
            break;
        }

        integrator
            .step(&system, delta_t, yo.view(), y_next.view_mut())
            .unwrap();
        r.set(y_next[0], y_next[1], y_next[2]);
        v.set(y_next[3], y_next[4], y_next[5]);
        rnow = r.length();
        // Keep the reported time in step with the state now held in `r`/`v`.
        time = step_start + delta_t;

        if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
            rhi = rnow;
            hi = delta_t;
        } else {
            rlo = rnow;
            lo = delta_t;
        }
    }

    time
}
/// Non-mutating variant of `stradv`, exported to Python under the name `stradv`.
///
/// Copies `r` and `v`, runs `stradv` on the copies, and returns the value
/// `stradv` produced together with the updated copies. The inputs themselves
/// are left unchanged.
#[pyfunction]
#[pyo3(name = "stradv")]
pub fn stradv_wrapper(q: f64, r: &Vec3, v: &Vec3, rad: f64, acc: f64, smax: f64) -> (f64, Vec3, Vec3) {
    let (mut position, mut velocity) = (*r, *v);
    let elapsed = stradv(q, &mut position, &mut velocity, rad, acc, smax);
    (elapsed, position, velocity)
}
/// Evaluates the acceleration components `(x, y, z)` for position `r` and
/// velocity `v`.
///
/// Two inverse-square attractions are combined: one of weight `1 / (1 + q)`
/// centred on the origin and one of weight `q / (1 + q)` centred on
/// `(1, 0, 0)`. The x and y components additionally include terms in
/// `2 * v.y`, `-2 * v.x`, `r.x - q / (1 + q)` and `r.y`.
///
/// NOTE(review): the velocity and position terms look like the Coriolis and
/// centrifugal contributions of a co-rotating frame — confirm against the
/// reference formulation.
#[pyfunction]
pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
    // Weights of the two attracting bodies.
    let weight1: f64 = 1.0 / (1.0 + q);
    let weight2: f64 = weight1 * q;

    // Squared distances to the body at the origin and to the body at x = 1.
    let perp_sq: f64 = r.y * r.y + r.z * r.z;
    let dist1_sq: f64 = r.x * r.x + perp_sq;
    let dx2: f64 = r.x - 1.0;
    let dist2_sq: f64 = dx2 * dx2 + perp_sq;

    // weight / distance^3 for each body.
    let pull1: f64 = weight1 / (dist1_sq * (dist1_sq.sqrt()));
    let pull2: f64 = weight2 / (dist2_sq * (dist2_sq.sqrt()));
    let pull_sum = pull1 + pull2;

    let acc_x: f64 = -pull_sum * r.x + pull2 + 2.0 * v.y + r.x - weight2;
    let acc_y: f64 = -pull_sum * r.y - 2.0 * v.x + r.y;
    let acc_z: f64 = -pull_sum * r.z;
    (acc_x, acc_y, acc_z)
}
/// Equations of motion handed to the `bulirsch` integrator.
///
/// The state vector has six entries: three position components followed by
/// three velocity components.
struct OrbitalSystem {
/// Forwarded as the first argument of `rocacc` on every derivative
/// evaluation (presumably the binary mass ratio — confirm).
q: f64,
}
impl bulirsch::System for OrbitalSystem {
    type Float = f64;

    /// Writes the time derivative of the six-element state `y` into `dydt`.
    ///
    /// Entries 0..3 of `dydt` receive the velocity (entries 3..6 of `y`);
    /// entries 3..6 receive the acceleration returned by `rocacc` for the
    /// position and velocity held in `y`.
    fn system(&self, y: bulirsch::ArrayView1<Self::Float>, mut dydt: bulirsch::ArrayViewMut1<Self::Float>) {
        // The derivative of the position is the velocity.
        for axis in 0..3 {
            dydt[[axis]] = y[[axis + 3]];
        }

        // The derivative of the velocity is the acceleration.
        let position = Vec3::new(y[[0]], y[[1]], y[[2]]);
        let velocity = Vec3::new(y[[3]], y[[4]], y[[5]]);
        let (acc_x, acc_y, acc_z) = rocacc(self.q, &position, &velocity);
        dydt[[3]] = acc_x;
        dydt[[4]] = acc_y;
        dydt[[5]] = acc_z;
    }
}