1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6use numpy::{IntoPyArray, PyArray1};
7
8#[pyfunction]
22pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
23 const SMALL: f64 = 1.0e-5;
24 let rl1: f64 = x_l1(q)?;
25 let mu: f64 = q / (1.0 + q);
26 let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
27 let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
28 let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
29
30 let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
31 let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
32
33 Ok((r, v))
34}
35
36pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
55 if n_points < 2 {
56 return Err(RocheError::ParameterError(
57 "Need at least 2 points in the stream.".to_string(),
58 ));
59 }
60
61 if step <= 0.0 || step > 1.0 {
62 return Err(RocheError::ParameterError(
63 "Step size must be between 0.0 and 1.0".to_string(),
64 ));
65 }
66
67 if q <= 0.0 {
68 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
69 }
70
71 let mut x_arr: Vec<f64> = vec![];
72 let mut y_arr: Vec<f64> = vec![];
73
74 let rl1: f64 = x_l1(q)?;
76 let (mut r, mut v) = strinit(q)?;
77
78 x_arr.push(rl1);
80 y_arr.push(0.0);
81
82 let mut lp: usize = 0;
83
84 let mut dist: f64 = (r.x - rl1).hypot(r.y);
88
89 let frac: f64;
90
91 if dist > step {
92 frac = step / dist;
93 x_arr.push(rl1 + (r.x - rl1) * frac);
94 y_arr.push(r.y * frac);
95 lp += 1;
96 }
97
98 let system = OrbitalSystem { q };
100 let mut integrator = Integrator::default()
101 .with_abs_tol(1.0e-8)
102 .with_rel_tol(1.0e-8)
103 .into_adaptive();
104 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
106 let mut y_next = ndarray::Array::zeros(y.raw_dim());
107
108 let mut delta_t: f64 = 1.0e-3;
109 let smax: f64 = (1.0e-3_f64).min(step / 2.0);
110
111 let mut vel: f64;
112 while lp < n_points - 1 {
113 integrator
114 .step(&system, delta_t, y.view(), y_next.view_mut())
115 .unwrap();
116 y.assign(&y_next);
117
118 r.set(y[0], y[1], y[2]);
119 v.set(y[3], y[4], y[5]);
120 dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
121 if dist > step {
122 let frac: f64 = step / dist;
123 x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
124 y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
125 lp += 1;
126 }
127 vel = v.x.hypot(v.y);
128 delta_t = (smax / vel).min(delta_t);
129 }
130
131 Ok((x_arr, y_arr))
132}
133
134#[pyfunction]
153#[pyo3(name = "stream", signature = (q, step, n_points=200))]
154pub fn stream_py(py: Python, q: f64, step: f64, n_points: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
155 let (x_arr, y_arr) = stream(q, step, n_points)?;
156 Ok((x_arr.into_pyarray(py).unbind(), y_arr.into_pyarray(py).unbind()))
157}
158
159pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
172 let mut dir: f64;
173 let mut lo: f64;
174 let mut hi: f64;
175 let mut ro: Vec3 = *r;
176 let mut vo: Vec3 = *v;
177
178 let mut delta_t: f64 = 1.0e-2;
179
180 dir = r.dot(v);
182 let dir1: f64 = dir;
183
184 let system = OrbitalSystem { q };
186 let mut integrator = Integrator::default()
187 .with_abs_tol(1.0e-8)
188 .with_rel_tol(1.0e-8)
189 .into_adaptive();
190 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
192 let mut y_next = ndarray::Array::zeros(y.raw_dim());
193 let mut yo = y.clone();
194
195 while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
196 ro = *r;
197 vo = *v;
198 yo = y.clone();
199 integrator
200 .step(&system, delta_t, y.view(), y_next.view_mut())
201 .unwrap();
202 y.assign(&y_next);
203 r.set(y[0], y[1], y[2]);
204 v.set(y[3], y[4], y[5]);
205 dir = r.dot(v);
206 }
207
208 lo = 0.0;
212 hi = delta_t;
213 while (hi - lo).abs() > acc {
214 delta_t = (lo + hi) / 2.0;
215 y = yo.clone();
216 *r = ro;
217 *v = vo;
218 integrator
219 .step(&system, delta_t, y.view(), y_next.view_mut())
220 .unwrap();
221 y.assign(&y_next);
222
223 r.set(y[0], y[1], y[2]);
224 v.set(y[3], y[4], y[5]);
225 dir = r.dot(v);
226 if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
227 hi = delta_t;
228 } else {
229 lo = delta_t;
230 }
231 }
232
233 Ok(())
234}
235
236#[pyfunction]
254#[pyo3(name = "strmnx")]
255#[pyo3(signature = (q, n=1, acc=1.0e-7))]
256pub fn strmnx_wrapper(
257 q: f64,
258 n: usize,
259 acc: f64,
260) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
261 let (mut r, mut v) = strinit(q)?;
262 for _ in 0..n {
263 strmnx(q, &mut r, &mut v, acc)?
264 }
265 let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
266 let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
267 Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
268}
269
270pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
287 if n_points < 2 {
288 return Err(RocheError::ParameterError(
289 "Need at least 2 points in the stream.".to_string(),
290 ));
291 }
292
293 if q <= 0.0 {
294 return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
295 }
296
297 const EPS: f64 = 1.0e-8;
298
299 let mut x_arr: Vec<f64> = vec![];
300 let mut y_arr: Vec<f64> = vec![];
301
302 let rl1: f64 = x_l1(q)?;
304 let (mut r, mut v) = strinit(q)?;
305 let rs = r;
306 let vs = v;
307 strmnx(q, &mut r, &mut v, EPS)?;
308 let rmin = if r.length() > rad { r.length() } else { rad };
309
310 r = rs;
311 v = vs;
312 x_arr.push(r.x);
313 y_arr.push(r.y);
314 let mut rnext: f64;
315 for i in 1..n_points {
316 rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
317 stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
318 x_arr.push(r.x);
319 y_arr.push(r.y);
320 }
321
322 Ok((x_arr, y_arr))
323}
324
325#[pyfunction]
342#[pyo3(name = "streamr", signature = (q, rad, n_points=200))]
343pub fn streamr_py(py: Python, q: f64, rad: f64, n_points: usize) -> PyResult<(Py<PyArray1<f64>>, Py<PyArray1<f64>>)> {
344 let (x_arr, y_arr) = streamr(q, rad, n_points)?;
345 Ok((x_arr.into_pyarray(py).unbind(), y_arr.into_pyarray(py).unbind()))
346}
347
348pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
371 const TMAX: f64 = 10.0;
372 let t_next: f64 = 1.0e-2;
373
374 let mut time: f64 = 0.0;
375
376 let mut ro = *r;
378 let mut vo = *v;
379
380 let rinit: f64 = r.length();
382 let mut rnow: f64 = rinit;
383
384 let system = OrbitalSystem { q };
386 let mut integrator = Integrator::default()
387 .with_abs_tol(1.0e-8)
388 .with_rel_tol(1.0e-8)
389 .into_adaptive();
390 let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
392 let mut y_next = ndarray::Array::zeros(y.raw_dim());
393
394 let mut yo = y.clone();
395 let mut delta_t = t_next.min(smax);
396 while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
398 ro = *r;
399 vo = *v;
400 yo = y.clone();
401 integrator
402 .step(&system, delta_t, y.view(), y_next.view_mut())
403 .unwrap();
404 y.assign(&y_next);
405 r.set(y[0], y[1], y[2]);
406 v.set(y[3], y[4], y[5]);
407 rnow = r.length();
408 time += delta_t;
409
410 if time > TMAX {
411 panic!("roche::stradv taken too long without crossing given radius.")
412 }
413 }
414
415 let mut lo: f64 = 0.0;
419 let mut hi: f64 = delta_t;
420 let mut rlo: f64 = ro.length();
421 let mut rhi: f64 = rnow;
422 let to: f64 = time;
423
424 while (rhi - rlo).abs() > acc {
425 delta_t = (lo + hi) / 2.0;
426 y = yo.clone();
427 *r = ro;
428 *v = vo;
429 time = to;
430
431 integrator
432 .step(&system, delta_t, y.view(), y_next.view_mut())
433 .unwrap();
434 y.assign(&y_next);
435
436 r.set(y[0], y[1], y[2]);
437 v.set(y[3], y[4], y[5]);
438 rnow = r.length();
439
440 if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
441 rhi = rnow;
442 hi = delta_t;
443 } else {
444 rlo = rnow;
445 lo = delta_t;
446 }
447 }
448
449 time
450}
451
452#[pyfunction]
472#[pyo3(name = "stradv")]
473pub fn stradv_py(
474 q: f64,
475 r: &Vec3,
476 v: &Vec3,
477 rad: f64,
478 acc: f64,
479 smax: f64,
480) -> (f64, Vec3, Vec3) {
481 let mut r_mut = *r;
482 let mut v_mut = *v;
483 let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
484 (timestep, r_mut, v_mut)
485}
486
487#[pyfunction]
496pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
497 let f1: f64 = 1.0 / (1.0 + q);
498 let f2: f64 = f1 * q;
499
500 let yzsq: f64 = r.y * r.y + r.z * r.z;
501 let r1sq: f64 = r.x * r.x + yzsq;
502 let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
503 let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
504 let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
505 let fm3: f64 = fm1 + fm2;
506
507 let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
508 let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
509 let z: f64 = -fm3 * r.z;
510 (x, y, z)
511}
512
513#[pyfunction]
528#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
529pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
530 let (mut r, mut v) = strinit(q)?;
531 let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
532
533 Ok(r)
534}
535
536#[pyfunction]
552#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
553pub fn bspot(q: f64, rad: f64, acc: f64, smax: f64) -> Result<(Vec3, Vec3), RocheError> {
554 let (mut r, mut v) = strinit(q)?;
555 let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
556
557 Ok((r, v))
558}
559
/// Equations of motion of a test particle in the co-rotating binary frame.
///
/// The state is integrated by the `bulirsch` integrator through the
/// `bulirsch::System` implementation for this type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct OrbitalSystem {
    /// Mass ratio passed through to [`rocacc`].
    pub q: f64,
}
563
564impl bulirsch::System for OrbitalSystem {
565 type Float = f64;
566
567 fn system(
568 &self,
569 y: bulirsch::ArrayView1<Self::Float>,
570 mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
571 ) {
572 dydt[[0]] = y[[3]];
573 dydt[[1]] = y[[4]];
574 dydt[[2]] = y[[5]];
575 let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
576 let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
577 (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Mass ratio used by the regression tests.
    const Q: f64 = 0.2;

    /// Euclidean distance between two vectors.
    fn separation(a: Vec3, b: Vec3) -> f64 {
        (a - b).length()
    }

    /// Distance of `(x, y)` from the reference `(ex, ey)`, relative to the
    /// reference's distance from the origin.
    fn relative_offset(x: f64, y: f64, ex: f64, ey: f64) -> f64 {
        (x - ex).hypot(y - ey) / ex.hypot(ey)
    }

    /// Reference bright-spot position for q = 0.2, rad = 0.3.
    fn expected_spot() -> Vec3 {
        Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0)
    }

    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        let (mut r, mut v) = strinit(Q)?;
        let _elapsed: f64 = stradv(Q, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);
        assert!(separation(r, expected_spot()) < 1.0e-7);
        assert!(
            separation(v, Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0)) < 1.0e-7
        );
        Ok(())
    }

    #[test]
    fn stream_test() -> Result<(), RocheError> {
        let (x, y) = stream(Q, 0.01, 200)?;
        let last = y.len() - 1;

        let checkpoints: [(usize, f64, f64); 5] = [
            (0, 0.6585557, 0.0),
            (50, 0.18384902, 0.15145306),
            (100, -0.100431986, -0.13697079),
            (150, 0.21720248, -0.4577784),
            (last, 0.15403406, 0.016731631),
        ];
        for &(i, ex, ey) in checkpoints.iter() {
            assert!(
                (x[i] - ex).hypot(y[i] - ey) < 1.0e-4,
                "stream point {} is off",
                i
            );
        }

        let bad_inputs: [(f64, f64, usize); 4] = [
            (-0.2, 0.0001, 200),
            (0.2, 1.1, 200),
            (0.2, -0.1, 200),
            (0.2, 0.0001, 1),
        ];
        for &(q, step, n) in bad_inputs.iter() {
            assert!(stream(q, step, n).is_err());
        }
        Ok(())
    }

    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(Q, 1, 1.0e-7)?;
        assert!(relative_offset(x, y, -0.08613947462186848, 0.05411592729509131) < 1.0e-6);
        assert!(relative_offset(vx1, vy1, -1.9727409465489645, -3.30679322752132) < 1.0e-6);
        assert!(relative_offset(vx2, vy2, -1.5225623467338747, -2.5902178683586605) < 1.0e-6);
        Ok(())
    }

    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        let r = brightspot_position(Q, 0.3, 1.0e-7, 1.0e-3)?;
        assert!(separation(r, expected_spot()) < 1.0e-7);
        Ok(())
    }

    #[test]
    fn bspot_test() -> Result<(), RocheError> {
        let (r, v) = bspot(Q, 0.3, 1.0e-7, 1.0e-3)?;
        assert!(separation(r, expected_spot()) < 1.0e-7);
        assert!(
            separation(v, Vec3::new(-1.476945722613775, 0.31712381223279495, 0.0)) < 1.0e-6
        );
        Ok(())
    }
}