// roche/stream_physics.rs

1use crate::errors::RocheError;
2use crate::x_l1;
3use crate::{Vec3, vel_transform};
4use bulirsch::{self, Integrator};
5use pyo3::prelude::*;
6
7///
8/// strinit sets a particle just inside the L1 point with the
9/// correct velocity as given in Lubow and Shu.
10///
11/// Arguments:
12///
13/// * `q`: mass ratio = M2/M1
14///
15/// Returns:
16///
17/// * start position
18/// * start velocity
19///
20#[pyfunction]
21pub fn strinit(q: f64) -> Result<(Vec3, Vec3), RocheError> {
22    const SMALL: f64 = 1.0e-5;
23    let rl1: f64 = x_l1(q)?;
24    let mu: f64 = q / (1.0 + q);
25    let a: f64 = (1.0 - mu) / rl1.powi(3) + mu / (1.0 - rl1).powi(3);
26    let lambda1: f64 = (((a - 2.0) + (a * (9.0 * a - 8.0)).sqrt()) / 2.0).sqrt();
27    let m1: f64 = (lambda1 * lambda1 - 2.0 * a - 1.0) / 2.0 / lambda1;
28
29    let r: Vec3 = Vec3::new(rl1 - SMALL, -m1 * SMALL, 0.0);
30    let v: Vec3 = Vec3::new(-lambda1 * SMALL, -lambda1 * m1 * SMALL, 0.0);
31
32    Ok((r, v))
33}
34
35///
36/// stream works by integrating the equations of motion for the Roche
37/// potential using Burlisch-Stoer integration. Every time the distance
38/// from the last point exceeds step, it interpolates and stores a new
39/// point. This allows one not to spend loads of points on regions where
40/// nothing is happening.
41///
42/// Arguments:
43///
44/// * `q`:    mass ratio = M2/M1. Stream flows from star 2 to 1.
45/// * `step`: step between points (units of separation).
46/// * `n_points`:    number of points to compute.
47///
48/// Returns:
49///
50/// * `x`:    array of x values returned.
51/// * `y`:    array of y values returned.
52///
53#[pyfunction]
54#[pyo3(signature = (q, step, n_points=200))]
55pub fn stream(q: f64, step: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
56    if n_points < 2 {
57        return Err(RocheError::ParameterError(
58            "Need at least 2 points in the stream.".to_string(),
59        ));
60    }
61
62    if step <= 0.0 || step > 1.0 {
63        return Err(RocheError::ParameterError(
64            "Step size must be between 0.0 and 1.0".to_string(),
65        ));
66    }
67
68    if q <= 0.0 {
69        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
70    }
71
72    let mut x_arr: Vec<f64> = vec![];
73    let mut y_arr: Vec<f64> = vec![];
74
75    // Initialise stream
76    let rl1: f64 = x_l1(q)?;
77    let (mut r, mut v) = strinit(q)?;
78
79    // Store L1 as first point
80    x_arr.push(rl1);
81    y_arr.push(0.0);
82
83    let mut lp: usize = 0;
84
85    // Store interpolation between L1 and initial point if
86    // step has been set small enough
87
88    let mut dist = (r.x - rl1).hypot(r.y);
89
90    let frac: f64;
91
92    if dist > step {
93        frac = step / dist;
94        x_arr.push(rl1 + (r.x - rl1) * frac);
95        y_arr.push(r.y * frac);
96        lp += 1;
97    }
98
99    // set up Bulirsch-Stoer integrator
100    let system = OrbitalSystem { q: q };
101    let mut integrator = Integrator::default()
102        .with_abs_tol(1.0e-8)
103        .with_rel_tol(1.0e-8)
104        .into_adaptive();
105    // Initialise arrays
106    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
107    let mut y_next = ndarray::Array::zeros(y.raw_dim());
108
109    let mut delta_t = 1.0e-3;
110    let smax = (1.0e-3_f64).min(step / 2.0);
111
112    let mut vel: f64;
113    while lp < n_points - 1 {
114        integrator
115            .step(&system, delta_t, y.view(), y_next.view_mut())
116            .unwrap();
117        y.assign(&y_next);
118
119        r.set(y[0], y[1], y[2]);
120        v.set(y[3], y[4], y[5]);
121        dist = (r.x - x_arr[lp]).hypot(r.y - y_arr[lp]);
122        if dist > step {
123            let frac: f64 = step / dist;
124            x_arr.push(x_arr[lp] + (r.x - x_arr[lp]) * frac);
125            y_arr.push(y_arr[lp] + (r.y - y_arr[lp]) * frac);
126            lp += 1;
127        }
128        vel = v.x.hypot(v.y);
129        delta_t = (smax / vel).min(delta_t);
130    }
131
132    Ok((x_arr, y_arr))
133}
134
135///
136/// strmnx finds the next point at which stream is closest or furthest
137/// from primary.
138///
139/// Arguments:
140///
141/// * `q`: mass ratio = M2/M1
142/// * `r`: initial and final position
143/// * `v`: initial and final velocity
144/// * `acc`: accuracy in time to locate minimum/maximum.
145///
146///
147pub fn strmnx(q: f64, r: &mut Vec3, v: &mut Vec3, acc: f64) -> Result<(), RocheError> {
148    let mut dir: f64;
149    let dir1: f64;
150    let mut lo: f64;
151    let mut hi: f64;
152    let mut ro: Vec3 = *r;
153    let mut vo: Vec3 = *v;
154
155    let mut delta_t: f64 = 1.0e-2;
156
157    // Store initial direction
158    dir = r.dot(v);
159    dir1 = dir;
160
161    // set up Bulirsch-Stoer integrator
162    let system = OrbitalSystem { q: q };
163    let mut integrator = Integrator::default()
164        .with_abs_tol(1.0e-8)
165        .with_rel_tol(1.0e-8)
166        .into_adaptive();
167    // Initialise arrays
168    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
169    let mut y_next = ndarray::Array::zeros(y.raw_dim());
170    let mut yo = y.clone();
171
172    while (dir > 0.0 && dir1 > 0.0) || (dir < 0.0 && dir1 < 0.0) {
173        ro = *r;
174        vo = *v;
175        yo = y.clone();
176        integrator
177            .step(&system, delta_t, y.view(), y_next.view_mut())
178            .unwrap();
179        y.assign(&y_next);
180        r.set(y[0], y[1], y[2]);
181        v.set(y[3], y[4], y[5]);
182        dir = r.dot(v);
183    }
184
185    //   Now refine by reinitialising and binary chopping until
186    //   close enough to requested radius.
187
188    lo = 0.0;
189    hi = delta_t;
190    while (hi - lo).abs() > acc {
191        delta_t = (lo + hi) / 2.0;
192        y = yo.clone();
193        *r = ro;
194        *v = vo;
195        integrator
196            .step(&system, delta_t, y.view(), y_next.view_mut())
197            .unwrap();
198        y.assign(&y_next);
199
200        r.set(y[0], y[1], y[2]);
201        v.set(y[3], y[4], y[5]);
202        dir = r.dot(v);
203        if (dir > 0.0 && dir1 < 0.0) || (dir < 0.0 && dir1 > 0.0) {
204            hi = delta_t;
205        } else {
206            lo = delta_t;
207        }
208    }
209
210    Ok(())
211}
212
213// wrapper for python library, avoiding mutable references
214
215///
216/// Calculates position & velocity of n-th turning point of stream.
217/// x,y,vx1,vy1,vx2,vy2 = strmnx(q, n=1, acc=1.e-7), q = M2/M1.
218/// Two sets of velocities are reported, the first for the pure stream,
219/// the second for the disk at that point.
220///
221/// Arguments:
222///
223/// * `q`: mass ratio = M2/M1
224/// * `n`: turning point number
225/// * `acc`: accuracy in time to locate minimum/maximum.
226///
227/// Returns:
228/// (x, y, vx1, vy1, vx2, vy2)
229///
230#[pyfunction]
231#[pyo3(name = "strmnx")]
232#[pyo3(signature = (q, n=1, acc=1.0e-7))]
233pub fn strmnx_wrapper(
234    q: f64,
235    n: usize,
236    acc: f64,
237) -> Result<(f64, f64, f64, f64, f64, f64), RocheError> {
238    let (mut r, mut v) = strinit(q)?;
239    for _ in 0..n {
240        strmnx(q, &mut r, &mut v, acc)?
241    }
242    let (tvx1, tvy1) = vel_transform(q, 1, r.x, r.y, v.x, v.y)?;
243    let (tvx2, tvy2) = vel_transform(q, 2, r.x, r.y, v.x, v.y)?;
244    Ok((r.x, r.y, tvx1, tvy1, tvx2, tvy2))
245}
246
247///
248/// streamr works by integrating the equations of motion for the Roche
249/// potential using Burlisch-Stoer integration. It stops when the stream
250/// reaches a target radius or a minimum radius, whichever is the larger.
251///
252/// Arguments:
253///
254/// * `q`: mass ratio = M2/M1. Stream flows from star 2 to 1.
255/// * `rad`: Radius to aim for. If this is less than the minimum, the stream will stop at the minimum
256/// * `n_points`: number of points to compute.
257///
258/// Results:
259///
260/// * `x`: array of x values returned.
261/// * `y`: array of y values returned.
262///
263#[pyfunction]
264#[pyo3(signature = (q, rad, n_points=200))]
265pub fn streamr(q: f64, rad: f64, n_points: usize) -> Result<(Vec<f64>, Vec<f64>), RocheError> {
266    if n_points < 2 {
267        return Err(RocheError::ParameterError(
268            "Need at least 2 points in the stream.".to_string(),
269        ));
270    }
271
272    if q <= 0.0 {
273        return Err(RocheError::ParameterError("q = {} <= 0".to_string()));
274    }
275
276    const EPS: f64 = 1.0e-8;
277
278    let mut x_arr: Vec<f64> = vec![];
279    let mut y_arr: Vec<f64> = vec![];
280
281    // Initialise stream
282    let rl1: f64 = x_l1(q)?;
283    let (mut r, mut v) = strinit(q)?;
284    let rs = r;
285    let vs = v;
286    strmnx(q, &mut r, &mut v, EPS)?;
287    let rmin = if r.length() > rad { r.length() } else { rad };
288
289    r = rs;
290    v = vs;
291    x_arr.push(r.x);
292    y_arr.push(r.y);
293    let mut rnext: f64;
294    for i in 1..n_points {
295        rnext = rl1 + (rmin - rl1) * (i as f64) / (n_points as f64 - 1.0);
296        stradv(q, &mut r, &mut v, rnext, 1.0e-6, 1.0e-4);
297        x_arr.push(r.x);
298        y_arr.push(r.y);
299    }
300
301    Ok((x_arr, y_arr))
302}
303
304///
305/// stradv advances a particle of given position and velocity until
306/// it reaches a specified radius. It then returns with updated position and
307/// velocity. It is up to the user not to request a value that cannot be reached.
308///
309/// Arguments:
310///
311/// * `q`:    mass ratio = M2/M1
312/// * `r`:    Initial and final position
313/// * `v`:    Initial and final velocity
314/// * `rad`:  Radius to aim for
315/// * `acc`:  Accuracy with which to place output point at rad.
316/// * `smax`: Largest time step allowed. It is possible that the
317/// routine could take such a large step that it misses
318/// the point when the stream is inside the requested
319/// radius. This allows one to control this. Typical
320/// value = 1.e-3.
321///
322/// Returns:
323///
324/// * time step taken
325///
326pub fn stradv(q: f64, r: &mut Vec3, v: &mut Vec3, rad: f64, acc: f64, smax: f64) -> f64 {
327    const TMAX: f64 = 10.0;
328    let t_next: f64 = 1.0e-2;
329
330    let mut time: f64 = 0.0;
331
332    // let to: f64;
333    let mut ro = *r;
334    let mut vo = *v;
335
336    // Store initial radius
337    let rinit: f64 = r.length();
338    let mut rnow: f64 = rinit;
339
340    // set up Bulirsch-Stoer integrator
341    let system = OrbitalSystem { q: q };
342    let mut integrator = Integrator::default()
343        .with_abs_tol(1.0e-8)
344        .with_rel_tol(1.0e-8)
345        .into_adaptive();
346    // Initialise arrays
347    let mut y = ndarray::array![r.x, r.y, r.z, v.x, v.y, v.z];
348    let mut y_next = ndarray::Array::zeros(y.raw_dim());
349
350    let mut yo = y.clone();
351    let mut delta_t = t_next.min(smax);
352    // Step until radius crossed
353    while (rinit > rad && rnow > rad) || (rinit < rad && rnow < rad) {
354        ro = *r;
355        vo = *v;
356        yo = y.clone();
357        integrator
358            .step(&system, delta_t, y.view(), y_next.view_mut())
359            .unwrap();
360        y.assign(&y_next);
361        r.set(y[0], y[1], y[2]);
362        v.set(y[3], y[4], y[5]);
363        rnow = r.length();
364        time += delta_t;
365
366        if time > TMAX {
367            panic!("roche::stradv taken too long without crossing given radius.")
368        }
369    }
370
371    // Now refine by reinitialising and binary chopping until
372    // close enough to requested radius.
373
374    let mut lo: f64 = 0.0;
375    let mut hi: f64 = delta_t;
376    let mut rlo: f64 = ro.length();
377    let mut rhi: f64 = rnow;
378    let to: f64 = time;
379
380    while (rhi - rlo).abs() > acc {
381        delta_t = (lo + hi) / 2.0;
382        y = yo.clone();
383        *r = ro;
384        *v = vo;
385        time = to;
386
387        integrator
388            .step(&system, delta_t, y.view(), y_next.view_mut())
389            .unwrap();
390        y.assign(&y_next);
391
392        r.set(y[0], y[1], y[2]);
393        v.set(y[3], y[4], y[5]);
394        rnow = r.length();
395
396        if (rhi > rad && rnow > rad) || (rhi < rad && rnow < rad) {
397            rhi = rnow;
398            hi = delta_t;
399        } else {
400            rlo = rnow;
401            lo = delta_t;
402        }
403    }
404
405    time
406}
407
408// wrapper for python library, avoiding mutable references
409
410///
411/// stradv advances a particle of given position and velocity until
412/// it reaches a specified radius. It then returns with updated position and
413/// velocity. It is up to the user not to request a value that cannot be reached.
414///
415/// \param q    mass ratio = M2/M1
416/// \param r    Initial position
417/// \param v    Initial velocity
418/// \param rad  Radius to aim for
419/// \param acc  Accuracy with which to place output point at rad.
420/// \param smax Largest time step allowed. It is possible that the
421/// routine could take such a large step that it misses
422/// the point when the stream is inside the requested
423/// radius. This allows one to control this. Typical
424/// value = 1.e-3.
425/// \returns (timestep, new position, new velocity)
426///
427#[pyfunction]
428#[pyo3(name = "stradv")]
429pub fn stradv_wrapper(
430    q: f64,
431    r: &Vec3,
432    v: &Vec3,
433    rad: f64,
434    acc: f64,
435    smax: f64,
436) -> (f64, Vec3, Vec3) {
437    let mut r_mut = *r;
438    let mut v_mut = *v;
439    let timestep = stradv(q, &mut r_mut, &mut v_mut, rad, acc, smax);
440    (timestep, r_mut, v_mut)
441}
442
443///
444/// rocacc calculates and returns the acceleration (in the rotating frame)
445/// in a Roche potential of a particle of given position and velocity.
446///
447/// \param q mass ratio = M2/M1
448/// \param r position, scaled in units of separation.
449/// \param v velocity, scaled in units of separation
450///
451#[pyfunction]
452pub fn rocacc(q: f64, r: &Vec3, v: &Vec3) -> (f64, f64, f64) {
453    let f1: f64 = 1.0 / (1.0 + q);
454    let f2: f64 = f1 * q;
455
456    let yzsq: f64 = r.y * r.y + r.z * r.z;
457    let r1sq: f64 = r.x * r.x + yzsq;
458    let r2sq: f64 = (r.x - 1.0) * (r.x - 1.0) + yzsq;
459    let fm1: f64 = f1 / (r1sq * (r1sq.sqrt()));
460    let fm2: f64 = f2 / (r2sq * (r2sq.sqrt()));
461    let fm3 = fm1 + fm2;
462
463    let x: f64 = -fm3 * r.x + fm2 + 2.0 * v.y + r.x - f2;
464    let y: f64 = -fm3 * r.y - 2.0 * v.x + r.y;
465    let z: f64 = -fm3 * r.z;
466    (x, y, z)
467}
468
469///
470/// brightspot_position runs strinit then stradv to get the coordinates of
471/// of the gas stream when it reaches a given radius from the primary star.
472///
473/// Arguments:
474///
475/// * `q`:  mass ratio = M2/M1
476/// * `rad`: radius from primary star
477/// * `acc`: computational accuracy
478/// * `smax`: maximum time step of Bulirsch-Stoer integration
479///
480/// Returns:
481/// * `r`: Vec3 coordinates of gas stream at given radius from primary star
482///
483#[pyfunction]
484#[pyo3(signature = (q, rad, acc=1.0e-7, smax=1.0e-2))]
485pub fn brightspot_position(q: f64, rad: f64, acc: f64, smax: f64) -> Result<Vec3, RocheError> {
486    let (mut r, mut v) = strinit(q)?;
487    let _ = stradv(q, &mut r, &mut v, rad, acc, smax);
488
489    Ok(r)
490}
491
492struct OrbitalSystem {
493    q: f64,
494}
495
496impl bulirsch::System for OrbitalSystem {
497    type Float = f64;
498
499    fn system(
500        &self,
501        y: bulirsch::ArrayView1<Self::Float>,
502        mut dydt: bulirsch::ArrayViewMut1<Self::Float>,
503    ) {
504        dydt[[0]] = y[[3]];
505        dydt[[1]] = y[[4]];
506        dydt[[2]] = y[[5]];
507        let r = Vec3::new(y[[0]], y[[1]], y[[2]]);
508        let v = Vec3::new(y[[3]], y[[4]], y[[5]]);
509        (dydt[[3]], dydt[[4]], dydt[[5]]) = rocacc(self.q, &r, &v);
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// True when (x, y) lies strictly within `tol` of (x_ref, y_ref).
    fn near(x: f64, y: f64, x_ref: f64, y_ref: f64, tol: f64) -> bool {
        (x - x_ref).hypot(y - y_ref) < tol
    }

    /// True when (x, y) matches (x_ref, y_ref) to relative accuracy `tol`.
    fn rel_near(x: f64, y: f64, x_ref: f64, y_ref: f64, tol: f64) -> bool {
        (x - x_ref).hypot(y - y_ref) / x_ref.hypot(y_ref) < tol
    }

    #[test]
    fn strinit_stradv_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let (mut r, mut v) = strinit(0.2)?;
        stradv(0.2, &mut r, &mut v, 0.3, 1.0e-7, 1.0e-3);
        let r_ref = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        let v_ref = Vec3::new(-1.4769457229627583, 0.31712381217252994, 0.0);
        assert!((r - r_ref).length() < 1.0e-7);
        assert!((v - v_ref).length() < 1.0e-7);
        Ok(())
    }

    #[test]
    fn stream_test() -> Result<(), RocheError> {
        // Values from trm.roche.stream
        let (x, y) = stream(0.2, 0.01, 200)?;
        let last = y.len() - 1;
        let expected = [
            (0, 0.6585557, 0.0),
            (50, 0.18384902, 0.15145306),
            (100, -0.100431986, -0.13697079),
            (150, 0.21720248, -0.4577784),
            (last, 0.15403406, 0.016731631),
        ];
        for &(i, x_ref, y_ref) in expected.iter() {
            assert!(near(x[i], y[i], x_ref, y_ref, 1.0e-4), "stream point {} off", i);
        }

        // Invalid q, step and n_points are all rejected.
        let bad_inputs = [
            (-0.2, 0.0001, 200),
            (0.2, 1.1, 200),
            (0.2, -0.1, 200),
            (0.2, 0.0001, 1),
        ];
        for &(q, step, n) in bad_inputs.iter() {
            assert!(stream(q, step, n).is_err());
        }
        Ok(())
    }

    #[test]
    fn strmnx_test() -> Result<(), RocheError> {
        // Values from trm.roche.strmnx
        let (x, y, vx1, vy1, vx2, vy2) = strmnx_wrapper(0.2, 1, 1.0e-7)?;
        assert!(rel_near(x, y, -0.08613947462186848, 0.05411592729509131, 1.0e-6));
        assert!(rel_near(vx1, vy1, -1.9727409465489645, -3.30679322752132, 1.0e-6));
        assert!(rel_near(vx2, vy2, -1.5225623467338747, -2.5902178683586605, 1.0e-6));
        Ok(())
    }

    #[test]
    fn brightspot_position_test() -> Result<(), RocheError> {
        // Values from trm.roche.bspot
        let r = brightspot_position(0.2, 0.3, 1.0e-7, 1.0e-3)?;
        let r_ref = Vec3::new(0.2660591412807423, 0.13860932478255575, 0.0);
        assert!((r - r_ref).length() < 1.0e-7);
        Ok(())
    }
}