use crate::errors::RocheError;
use crate::potential::{drpot, rpot};
use crate::x_lagrange::x_l1;
use crate::{Star, Vec3};
use pyo3::prelude::*;
use numpy::{IntoPyArray, PyArray1};
/// Residual function used to locate where a ray in the `z = 0` plane crosses
/// the equipotential `rpot(q, p) == cpot`.
///
/// The ray is anchored at the origin for [`Star::Primary`] and at `(1, 0, 0)`
/// for [`Star::Secondary`], and runs along the direction `(dx, dy, 0)`.
pub struct LineRoche {
    /// Mass ratio forwarded to `rpot` / `drpot`.
    pub q: f64,
    /// Selects the anchor point of the ray.
    pub star: Star,
    /// x component of the ray direction.
    pub dx: f64,
    /// y component of the ray direction.
    pub dy: f64,
    /// Target potential; the root of [`LineRoche::cost`] is where `rpot` equals this.
    pub cpot: f64,
}

impl LineRoche {
    /// Bundles the ray description into a `LineRoche`.
    pub fn new(q: f64, star: Star, dx: f64, dy: f64, cpot: f64) -> Self {
        LineRoche {
            q,
            star,
            dx,
            dy,
            cpot,
        }
    }

    /// Evaluates the residual at distance `lam` along the ray.
    ///
    /// Returns `(f, df/dlam)` where `f = rpot(q, p) - cpot` and `p` is the point
    /// `lam` units from the anchor along `(dx, dy, 0)`. The derivative is the
    /// projection of `drpot(q, p)` (used here as the gradient of `rpot`) onto
    /// the ray direction.
    ///
    /// # Errors
    /// Propagates any error from `rpot` (checked first) or `drpot`.
    pub fn cost(&self, lam: f64) -> Result<(f64, f64), RocheError> {
        let step_x = lam * self.dx;
        let step_y = lam * self.dy;
        let point = match self.star {
            Star::Primary => Vec3::new(step_x, step_y, 0.0),
            Star::Secondary => Vec3::new(1.0 + step_x, step_y, 0.0),
        };

        let residual = rpot(self.q, &point)? - self.cpot;
        let grad = drpot(self.q, &point)?;
        // Chain rule: d/dlam rpot(anchor + lam * u) = u . grad(rpot).
        let slope = self.dx * grad.x + self.dy * grad.y;
        Ok((residual, slope))
    }
}
/// Samples the primary's Roche-lobe outline in the `z = 0` plane.
///
/// The target potential is `rpot` evaluated at the L1 point `(x_l1(q), 0, 0)`.
/// Point `i` (for `i` in `0..n`) lies on the ray from the origin at angle
/// `2*pi*i/(n-1)`. The first and last points are placed exactly at L1; every
/// other radius is solved with [`rtsafe`] in the bracket `[rl1/4, rl1]`.
///
/// Returns `(x, y)` vectors of length `n`.
///
/// # Errors
/// Propagates errors from `x_l1`, `rpot`, and `rtsafe` (for example when the
/// root is not bracketed along a ray, or the iteration limit is exceeded).
pub fn lobe1(q: f64, n: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
    const FRAC: f64 = 1.0e-6;
    let rl1 = x_l1(q)?;
    let cpot = rpot(q, &Vec3::new(rl1, 0.0, 0.0))?;
    let min_radius = rl1 / 4.0;
    let steps = (n as f64) - 1.0;

    let mut xs = Vec::with_capacity(n);
    let mut ys = Vec::with_capacity(n);
    for i in 0..n {
        let (x, y) = if i == 0 || i + 1 == n {
            // L1 is pinned exactly rather than solved for; presumably because
            // the potential is stationary there, which would stall Newton steps.
            (rl1, 0.0)
        } else {
            let theta = (i as f64) * std::f64::consts::PI * 2.0 / steps;
            let (ux, uy) = (theta.cos(), theta.sin());
            let ray = LineRoche::new(q, Star::Primary, ux, uy, cpot);
            let r = rtsafe(min_radius, rl1, |lam| ray.cost(lam), FRAC)?;
            (r * ux, r * uy)
        };
        xs.push(x);
        ys.push(y);
    }
    Ok((xs, ys))
}
/// Python binding for `lobe1`: the primary's Roche-lobe outline as two NumPy
/// float64 arrays `(x, y)` with `n` points (default 200). Errors from the Rust
/// routine are converted into a Python exception.
#[pyfunction]
#[pyo3(name = "lobe1", signature = (q, n=200))]
pub fn lobe1_py(py: Python, q: f64, n: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
    let (x_coords, y_coords) = lobe1(q, n)?;
    let x_array = x_coords.into_pyarray(py).unbind();
    let y_array = y_coords.into_pyarray(py).unbind();
    Ok((x_array, y_array))
}
/// Samples the secondary's Roche-lobe outline in the `z = 0` plane.
///
/// The target potential is `rpot` evaluated at the L1 point `(x_l1(q), 0, 0)`.
/// Point `i` (for `i` in `0..n`) lies on the ray from `(1, 0)` with direction
/// `(-cos t, sin t)`, `t = 2*pi*i/(n-1)`, so `t = 0` points back towards L1.
/// The first and last points are placed exactly at L1; every other radius is
/// solved with [`rtsafe`] in the bracket `[(1-rl1)/4, 1-rl1]`.
///
/// Returns `(x, y)` vectors of length `n`.
///
/// # Errors
/// Propagates errors from `x_l1`, `rpot`, and `rtsafe`.
pub fn lobe2(q: f64, n: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
    const FRAC: f64 = 1.0e-6;
    let rl1 = x_l1(q)?;
    let cpot = rpot(q, &Vec3::new(rl1, 0.0, 0.0))?;
    // Distance from the secondary's anchor (x = 1) to L1: upper end of the bracket.
    let max_radius = 1.0 - rl1;
    let min_radius = max_radius / 4.0;
    let steps = (n as f64) - 1.0;

    let mut xs = Vec::with_capacity(n);
    let mut ys = Vec::with_capacity(n);
    for i in 0..n {
        let (x, y) = if i == 0 || i + 1 == n {
            // L1 is pinned exactly rather than solved for.
            (rl1, 0.0)
        } else {
            let theta = (i as f64) * std::f64::consts::PI * 2.0 / steps;
            let (ux, uy) = (-theta.cos(), theta.sin());
            let ray = LineRoche::new(q, Star::Secondary, ux, uy, cpot);
            let r = rtsafe(min_radius, max_radius, |lam| ray.cost(lam), FRAC)?;
            (1.0 + r * ux, r * uy)
        };
        xs.push(x);
        ys.push(y);
    }
    Ok((xs, ys))
}
/// Python binding for `lobe2`: the secondary's Roche-lobe outline as two NumPy
/// float64 arrays `(x, y)` with `n` points (default 200). Errors from the Rust
/// routine are converted into a Python exception.
#[pyfunction]
#[pyo3(name = "lobe2", signature = (q, n=200))]
pub fn lobe2_py(py: Python, q: f64, n: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
    let (x_coords, y_coords) = lobe2(q, n)?;
    let x_array = x_coords.into_pyarray(py).unbind();
    let y_array = y_coords.into_pyarray(py).unbind();
    Ok((x_array, y_array))
}
/// Maps the primary's Roche-lobe outline (from [`lobe1`]) into velocity coordinates.
///
/// Each point `(x, y)` becomes `(-y, x - mu)` with `mu = q / (1 + q)`. That is
/// `z_hat x (r - r_cm)` with the centre of mass taken at `(mu, 0)`. Presumably
/// this is the velocity of a co-rotating point in units where the separation
/// and orbital angular frequency are 1.
///
/// Returns `(vx, vy)` vectors of length `n`.
///
/// # Errors
/// Propagates any error from [`lobe1`].
pub fn vlobe1(q: f64, n: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
    let (x, y) = lobe1(q, n)?;
    let mu = q / (1.0 + q);
    // zip/unzip avoids per-index bounds checks and sizes both outputs up front.
    Ok(x.iter().zip(&y).map(|(&xi, &yi)| (-yi, xi - mu)).unzip())
}
/// Python binding for `vlobe1`: the primary's lobe in velocity coordinates as
/// two NumPy float64 arrays `(vx, vy)` with `n` points (default 200). Errors from
/// the Rust routine are converted into a Python exception.
#[pyfunction]
#[pyo3(name = "vlobe1", signature = (q, n=200))]
pub fn vlobe1_py(py: Python, q: f64, n: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
    let (vx_values, vy_values) = vlobe1(q, n)?;
    let vx_array = vx_values.into_pyarray(py).unbind();
    let vy_array = vy_values.into_pyarray(py).unbind();
    Ok((vx_array, vy_array))
}
/// Maps the secondary's Roche-lobe outline (from [`lobe2`]) into velocity coordinates.
///
/// Each point `(x, y)` becomes `(-y, x - mu)` with `mu = q / (1 + q)`. That is
/// `z_hat x (r - r_cm)` with the centre of mass taken at `(mu, 0)`. Presumably
/// this is the velocity of a co-rotating point in units where the separation
/// and orbital angular frequency are 1.
///
/// Returns `(vx, vy)` vectors of length `n`.
///
/// # Errors
/// Propagates any error from [`lobe2`].
pub fn vlobe2(q: f64, n: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
    let (x, y) = lobe2(q, n)?;
    let mu = q / (1.0 + q);
    // zip/unzip avoids per-index bounds checks and sizes both outputs up front.
    Ok(x.iter().zip(&y).map(|(&xi, &yi)| (-yi, xi - mu)).unzip())
}
/// Python binding for `vlobe2`: the secondary's lobe in velocity coordinates as
/// two NumPy float64 arrays `(vx, vy)` with `n` points (default 200). Errors from
/// the Rust routine are converted into a Python exception.
#[pyfunction]
#[pyo3(name = "vlobe2", signature = (q, n=200))]
pub fn vlobe2_py(py: Python, q: f64, n: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
    let (vx_values, vy_values) = vlobe2(q, n)?;
    let vx_array = vx_values.into_pyarray(py).unbind();
    let vy_array = vy_values.into_pyarray(py).unbind();
    Ok((vx_array, vy_array))
}
/// Finds a root of a function inside `[x1, x2]` with safeguarded Newton-Raphson.
///
/// `func(x)` must return `(f(x), f'(x))`. It may be stateful (`FnMut`), for
/// example to count evaluations; any `Fn` closure is also accepted.
///
/// The function is evaluated at both ends of the bracket first:
/// * If `f(x1)` and `f(x2)` are both strictly positive or both strictly
///   negative, the root is not bracketed and an error is returned.
/// * If either end is exactly zero, that end is returned immediately.
///
/// Otherwise each iteration takes a Newton step. It falls back to bisection
/// when the Newton step would land on or outside the current bracket (note the
/// `>= 0.0` test), or when it would not at least halve the step taken two
/// iterations earlier.
///
/// Iteration stops when the last step is smaller than `xacc` (absolute
/// tolerance) or when the step no longer changes the estimate.
///
/// # Errors
/// * `RocheError::RtsafeError` if the root is not bracketed, or if 100
///   iterations pass without convergence.
/// * Any error returned by `func` is propagated unchanged.
pub fn rtsafe<F>(x1: f64, x2: f64, mut func: F, xacc: f64) -> Result<f64, RocheError>
where
    F: FnMut(f64) -> Result<(f64, f64), RocheError>,
{
    const MAXITER: u32 = 100;

    let (fl, _) = func(x1)?;
    let (fh, _) = func(x2)?;
    if (fl > 0.0 && fh > 0.0) || (fl < 0.0 && fh < 0.0) {
        return Err(RocheError::RtsafeError(
            "Root must be bracketed in rtsafe".to_string(),
        ));
    }
    if fl == 0.0 {
        return Ok(x1);
    }
    if fh == 0.0 {
        return Ok(x2);
    }

    // Orient the bracket so that f(xlo) < 0 and f(xhi) > 0.
    let (mut xlo, mut xhi) = if fh < 0.0 { (x2, x1) } else { (x1, x2) };

    let mut rts = 0.5 * (xlo + xhi);
    let mut dxold = (xhi - xlo).abs();
    let mut dx = dxold;
    let (mut f, mut df) = func(rts)?;

    for _ in 0..MAXITER {
        let newton_leaves_bracket = ((rts - xhi) * df - f) * ((rts - xlo) * df - f) >= 0.0;
        let newton_too_slow = (2.0 * f).abs() > (dxold * df).abs();
        if newton_leaves_bracket || newton_too_slow {
            // Bisection step.
            dxold = dx;
            dx = 0.5 * (xhi - xlo);
            rts = xlo + dx;
            if xlo == rts {
                // Bracket has collapsed to floating-point resolution.
                return Ok(rts);
            }
        } else {
            // Newton step.
            dxold = dx;
            dx = f / df;
            let previous = rts;
            rts -= dx;
            if previous == rts {
                // Step is below floating-point resolution.
                return Ok(rts);
            }
        }
        if dx.abs() < xacc {
            return Ok(rts);
        }
        (f, df) = func(rts)?;
        // Keep the root bracketed between xlo (f < 0) and xhi (f >= 0).
        if f < 0.0 {
            xlo = rts;
        } else {
            xhi = rts;
        }
    }
    Err(RocheError::RtsafeError(
        "Maximum number of iterations exceeded in rtsafe".to_string(),
    ))
}