Skip to main content

lox_orbits/propagators/
numerical.rs

1// SPDX-FileCopyrightText: 2025 Helge Eichhorn <git@helgeeichhorn.de>
2//
3// SPDX-License-Identifier: MPL-2.0
4
5use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
6
7use differential_equations::{
8    interpolate::Interpolation,
9    ode::{ODE, ODEProblem, OrdinaryNumericalMethod},
10    prelude::ExplicitRungeKutta,
11    traits::State,
12};
13use glam::DVec3;
14use lox_bodies::{
15    DynOrigin, J2, MeanRadius, Origin, PointMass, TryJ2, TryMeanRadius, TryPointMass,
16    UndefinedOriginPropertyError,
17};
18use lox_core::coords::Cartesian;
19use lox_frames::{DynFrame, ReferenceFrame};
20use lox_time::Time;
21use lox_time::deltas::TimeDelta;
22use lox_time::intervals::TimeInterval;
23use lox_time::time_scales::{DynTimeScale, TimeScale};
24use thiserror::Error;
25
26use crate::orbits::{CartesianOrbit, TrajectorError, Trajectory};
27use crate::propagators::Propagator;
28
/// How many maximum-size integration steps fit into one characteristic orbital
/// timescale `r/v`.
///
/// For a circular orbit `r/v = T/(2π)`, so a value of 8 permits roughly
/// 8 · 2π ≈ 50 maximum-size steps per revolution.
const H_MAX_STEPS_PER_TIMESCALE: f64 = 8.0;
32
33/// Errors that can occur during J2-perturbed numerical propagation.
34#[derive(Debug, Error)]
35pub enum J2Error {
36    /// The ODE solver failed with the given message.
37    #[error("ODE solver failed: {0}")]
38    Solver(String),
39    /// The ODE solver returned an empty solution.
40    #[error("ODE solver returned no solution")]
41    EmptySolution,
42    /// Fewer than two time steps were provided to `propagate_to`.
43    #[error("at least two time steps are needed")]
44    InvalidTimeSteps,
45    /// Error constructing the output trajectory.
46    #[error(transparent)]
47    Trajectory(#[from] TrajectorError),
48}
49
50/// Numerical orbit propagator with J2 zonal harmonic perturbation.
51#[derive(Debug, Clone, Copy)]
52pub struct J2Propagator<T: TimeScale, O: TryJ2 + TryPointMass + TryMeanRadius, R: ReferenceFrame> {
53    initial_state: CartesianOrbit<T, O, R>,
54    rtol: f64,
55    atol: f64,
56    h_max: f64,
57    h_min: f64,
58    max_steps: usize,
59}
60
61/// Type alias for a [`J2Propagator`] using dynamic time scale, origin, and frame.
62pub type DynJ2Propagator = J2Propagator<DynTimeScale, DynOrigin, DynFrame>;
63
64fn default_h_max(position: DVec3, velocity: DVec3) -> f64 {
65    position.length() / velocity.length() / H_MAX_STEPS_PER_TIMESCALE
66}
67
68// Infallible — static bounds
69impl<T, O, R> J2Propagator<T, O, R>
70where
71    T: TimeScale,
72    O: J2 + PointMass + MeanRadius + Copy,
73    R: ReferenceFrame,
74{
75    /// Create a new J2 propagator from the given initial state.
76    pub fn new(initial_state: CartesianOrbit<T, O, R>) -> Self {
77        let h_max = default_h_max(initial_state.position(), initial_state.velocity());
78        Self {
79            initial_state,
80            rtol: 1e-8,
81            atol: 1e-6,
82            h_max,
83            h_min: 1e-6,
84            max_steps: 100_000,
85        }
86    }
87}
88
89// Fallible — Try* bounds (covers DynOrigin)
90impl<T, O, R> J2Propagator<T, O, R>
91where
92    T: TimeScale,
93    O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
94    R: ReferenceFrame,
95{
96    /// Try to create a new J2 propagator, returning an error if the origin lacks required properties.
97    pub fn try_new(
98        initial_state: CartesianOrbit<T, O, R>,
99    ) -> Result<Self, UndefinedOriginPropertyError> {
100        initial_state.origin().try_gravitational_parameter()?;
101        initial_state.origin().try_j2()?;
102        initial_state.origin().try_mean_radius()?;
103
104        let h_max = default_h_max(initial_state.position(), initial_state.velocity());
105        Ok(Self {
106            initial_state,
107            rtol: 1e-8,
108            atol: 1e-6,
109            h_max,
110            h_min: 1e-6,
111            max_steps: 100_000,
112        })
113    }
114}
115
116impl<T, O, R> J2Propagator<T, O, R>
117where
118    T: TimeScale,
119    O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
120    R: ReferenceFrame,
121{
122    /// Set the relative tolerance for the ODE solver.
123    pub fn with_rtol(mut self, rtol: f64) -> Self {
124        self.rtol = rtol;
125        self
126    }
127
128    /// Set the absolute tolerance for the ODE solver.
129    pub fn with_atol(mut self, atol: f64) -> Self {
130        self.atol = atol;
131        self
132    }
133
134    /// Set the maximum integration step size in seconds.
135    pub fn with_h_max(mut self, h_max: f64) -> Self {
136        self.h_max = h_max;
137        self
138    }
139
140    /// Set the minimum integration step size in seconds.
141    pub fn with_h_min(mut self, h_min: f64) -> Self {
142        self.h_min = h_min;
143        self
144    }
145
146    /// Set the maximum number of integration steps.
147    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
148        self.max_steps = max_steps;
149        self
150    }
151
152    /// Return a reference to the initial orbital state.
153    pub fn initial_state(&self) -> &CartesianOrbit<T, O, R> {
154        &self.initial_state
155    }
156
157    /// Return the central body origin.
158    pub fn origin(&self) -> O {
159        self.initial_state.origin()
160    }
161
162    /// Return the reference frame.
163    pub fn reference_frame(&self) -> R
164    where
165        R: Copy,
166    {
167        self.initial_state.reference_frame()
168    }
169
170    fn gravitational_parameter(&self) -> f64 {
171        self.initial_state
172            .origin()
173            .try_gravitational_parameter()
174            .expect("gravitational parameter should be available")
175            .as_f64()
176    }
177
178    fn j2(&self) -> f64 {
179        self.initial_state
180            .origin()
181            .try_j2()
182            .expect("J2 should be available")
183    }
184
185    fn mean_radius(&self) -> f64 {
186        self.initial_state
187            .origin()
188            .try_mean_radius()
189            .expect("mean radius should be available")
190            .as_f64()
191    }
192
193    fn solver(
194        &self,
195    ) -> impl OrdinaryNumericalMethod<f64, CartesianState> + Interpolation<f64, CartesianState>
196    {
197        ExplicitRungeKutta::dop853()
198            .rtol(self.rtol)
199            .atol(self.atol)
200            .h_max(self.h_max)
201            .h_min(self.h_min)
202            .max_steps(self.max_steps)
203    }
204}
205
206impl<T, O, R> ODE<f64, CartesianState> for J2Propagator<T, O, R>
207where
208    T: TimeScale,
209    O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
210    R: ReferenceFrame,
211{
212    fn diff(&self, _t: f64, s: &CartesianState, dydt: &mut CartesianState) {
213        let mu = self.gravitational_parameter();
214        let j2 = self.j2();
215        let rm = self.mean_radius();
216
217        let p = s.position();
218        let pm = p.length();
219        let pj = -3.0 / 2.0 * mu * j2 * rm.powi(2) / pm.powi(5);
220
221        let acc = -mu * p / pm.powi(3)
222            + pj * p * (DVec3::new(1.0, 1.0, 3.0) - 5.0 * p.z.powi(2) / pm.powi(2));
223
224        dydt.0.set_position(s.velocity());
225        dydt.0.set_velocity(acc);
226    }
227}
228
229impl<T, O, R> Propagator<T, O> for J2Propagator<T, O, R>
230where
231    T: TimeScale + Copy + PartialOrd,
232    O: TryJ2 + TryPointMass + TryMeanRadius + Origin + Copy,
233    R: ReferenceFrame + Copy,
234{
235    type Frame = R;
236    type Error = J2Error;
237
238    fn state_at(&self, time: Time<T>) -> Result<CartesianOrbit<T, O, R>, J2Error> {
239        let epoch = self.initial_state.time();
240        let t0 = 0.0_f64;
241        let t1 = (time - epoch).to_seconds().to_f64();
242        let s0 = CartesianState::from(*self.initial_state());
243
244        let mut solver = self.solver();
245
246        let problem = ODEProblem::new(self, t0, t1, s0);
247        let solution = problem
248            .solve(&mut solver)
249            .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
250
251        let (_, final_state) = solution.iter().next_back().ok_or(J2Error::EmptySolution)?;
252
253        let origin = self.initial_state.origin();
254        let frame = self.initial_state.reference_frame();
255        Ok(CartesianOrbit::new(final_state.0, time, origin, frame))
256    }
257
258    fn propagate(&self, interval: TimeInterval<T>) -> Result<Trajectory<T, O, R>, J2Error> {
259        let start = interval.start();
260
261        // Propagate to start of interval
262        let s0: CartesianState = if start != self.initial_state.time() {
263            self.state_at(start)?
264        } else {
265            *self.initial_state()
266        }
267        .into();
268
269        let t1 = (interval.end() - start).to_seconds().to_f64();
270
271        let mut solver = self.solver();
272
273        let problem = ODEProblem::new(self, 0.0, t1, s0);
274        let solution = problem
275            .solve(&mut solver)
276            .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
277
278        let origin = self.initial_state.origin();
279        let frame = self.initial_state.reference_frame();
280
281        Ok(solution
282            .iter()
283            .map(|(t, s)| {
284                CartesianOrbit::new(s.0, start + TimeDelta::from_seconds_f64(*t), origin, frame)
285            })
286            .collect())
287    }
288
289    fn propagate_to(
290        &self,
291        times: impl IntoIterator<Item = Time<T>>,
292    ) -> Result<Trajectory<T, O, Self::Frame>, Self::Error> {
293        let times: Vec<Time<T>> = times.into_iter().collect();
294        if times.len() < 2 {
295            return Err(J2Error::InvalidTimeSteps);
296        }
297
298        let t0 = times[0];
299        let steps: Vec<f64> = times
300            .iter()
301            .map(|t| (*t - t0).to_seconds().to_f64())
302            .collect();
303        let t1 = *steps.last().unwrap();
304
305        // Propagate to first time step
306        let s0: CartesianState = if t0 != self.initial_state.time() {
307            self.state_at(t0)?
308        } else {
309            *self.initial_state()
310        }
311        .into();
312
313        let mut solver = self.solver();
314
315        let problem = ODEProblem::new(self, 0.0, t1, s0);
316        let solution = problem
317            .t_eval(steps)
318            .solve(&mut solver)
319            .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
320
321        let origin = self.initial_state.origin();
322        let frame = self.initial_state.reference_frame();
323
324        Ok(solution
325            .iter()
326            .map(|(t, s)| {
327                CartesianOrbit::new(s.0, t0 + TimeDelta::from_seconds_f64(*t), origin, frame)
328            })
329            .collect())
330    }
331}
332
333/// A six-element Cartesian state vector (position + velocity) used as the ODE state.
334#[derive(Debug, Clone, Copy, Default, PartialEq)]
335#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
336pub struct CartesianState(Cartesian);
337
338impl CartesianState {
339    fn position(&self) -> DVec3 {
340        self.0.position()
341    }
342
343    fn velocity(&self) -> DVec3 {
344        self.0.velocity()
345    }
346}
347
348impl State<f64> for CartesianState {
349    fn len(&self) -> usize {
350        6
351    }
352
353    fn get(&self, i: usize) -> f64 {
354        match i {
355            0 => self.0.position().x,
356            1 => self.0.position().y,
357            2 => self.0.position().z,
358            3 => self.0.velocity().x,
359            4 => self.0.velocity().y,
360            5 => self.0.velocity().z,
361            _ => unreachable!("index out of bounds"),
362        }
363    }
364
365    fn set(&mut self, i: usize, value: f64) {
366        match i {
367            0 => {
368                self.0.set::<0>(value);
369            }
370            1 => {
371                self.0.set::<1>(value);
372            }
373            2 => {
374                self.0.set::<2>(value);
375            }
376            3 => {
377                self.0.set::<3>(value);
378            }
379            4 => {
380                self.0.set::<4>(value);
381            }
382            5 => {
383                self.0.set::<5>(value);
384            }
385            _ => unreachable!("index out of bounds"),
386        };
387    }
388
389    fn zeros() -> Self {
390        Self::default()
391    }
392}
393
394impl Add for CartesianState {
395    type Output = Self;
396
397    fn add(self, rhs: Self) -> Self::Output {
398        CartesianState(self.0 + rhs.0)
399    }
400}
401
402impl AddAssign for CartesianState {
403    fn add_assign(&mut self, rhs: Self) {
404        self.0 += rhs.0;
405    }
406}
407
408impl Sub for CartesianState {
409    type Output = Self;
410
411    fn sub(self, rhs: Self) -> Self::Output {
412        CartesianState(self.0 - rhs.0)
413    }
414}
415
416impl Mul<f64> for CartesianState {
417    type Output = Self;
418
419    fn mul(self, rhs: f64) -> Self::Output {
420        CartesianState(self.0 * rhs)
421    }
422}
423
424impl Div<f64> for CartesianState {
425    type Output = Self;
426
427    fn div(self, rhs: f64) -> Self::Output {
428        CartesianState(self.0 / rhs)
429    }
430}
431
432impl Neg for CartesianState {
433    type Output = Self;
434
435    fn neg(self) -> Self::Output {
436        CartesianState(-self.0)
437    }
438}
439
440impl<T, O, R> From<CartesianOrbit<T, O, R>> for CartesianState
441where
442    T: TimeScale,
443    O: Origin,
444    R: ReferenceFrame,
445{
446    fn from(orbit: CartesianOrbit<T, O, R>) -> Self {
447        Self(orbit.state())
448    }
449}
450
#[cfg(test)]
mod tests {
    use lox_bodies::Earth;
    use lox_frames::Icrf;
    use lox_test_utils::assert_approx_eq;
    use lox_time::Time;
    use lox_time::intervals::Interval;
    use lox_time::time;
    use lox_time::time_scales::Tdb;
    use lox_units::{DistanceUnits, VelocityUnits};

    use super::*;

    /// Position and velocity of the reference orbit shared by every test.
    fn initial_cartesian() -> Cartesian {
        Cartesian::new(
            1131.340.km(),
            -2282.343.km(),
            6672.423.km(),
            -5.64305.kps(),
            4.30333.kps(),
            2.42879.kps(),
        )
    }

    /// Reference orbit about Earth in the ICRF at 2023-01-01 TDB.
    fn initial_state() -> CartesianOrbit<Tdb, Earth, Icrf> {
        let epoch = time!(Tdb, 2023, 1, 1).unwrap();
        CartesianOrbit::new(initial_cartesian(), epoch, Earth, Icrf)
    }

    /// The right-hand side returns the velocity and the expected J2 acceleration.
    #[test]
    fn test_j2_ode() {
        let propagator = J2Propagator::new(initial_state());
        let state = CartesianState(initial_cartesian());

        let mut derivative = CartesianState::default();
        propagator.diff(0.0, &state, &mut derivative);

        let expected_acceleration =
            DVec3::new(-1.2324031762444367, 2.4862258582559233, -7.287340551142344);

        // The position derivative is the velocity itself.
        assert_eq!(derivative.position(), state.velocity());
        assert_approx_eq!(derivative.velocity(), expected_acceleration, rtol <= 1e-8);
    }

    /// A 40-minute propagation lands close to the reference final state.
    #[test]
    fn test_j2_propagator() {
        let orbit = initial_state();
        let start = orbit.time();
        let propagator = J2Propagator::new(orbit);

        let interval = Interval::new(start, start + TimeDelta::from_minutes(40));
        let trajectory = propagator.propagate(interval).unwrap();
        let last = trajectory.states().into_iter().last().unwrap();

        let expected_position = DVec3::new(
            -4255.223590627231e3,
            4384.471704756651e3,
            -3.936_135_007_962_321e6,
        );
        let expected_velocity = DVec3::new(
            3.6559899898490054e3,
            -1.884445831960271e3,
            -6.123308149589636e3,
        );
        assert_approx_eq!(last.position(), expected_position, rtol <= 1e-1);
        assert_approx_eq!(last.velocity(), expected_velocity, rtol <= 1e-1);
    }

    /// Propagating over [epoch, epoch+40m] and over [epoch+20m, epoch+40m] must end in
    /// the same state.
    #[test]
    fn test_propagate_with_offset_interval() {
        let orbit = initial_state();
        let epoch = orbit.time();
        let propagator = J2Propagator::new(orbit);

        let end = epoch + TimeDelta::from_minutes(40);
        let offset_start = epoch + TimeDelta::from_minutes(20);

        let from_epoch = propagator.propagate(Interval::new(epoch, end)).unwrap();
        let last_from_epoch = from_epoch.states().into_iter().last().unwrap();

        let from_offset = propagator
            .propagate(Interval::new(offset_start, end))
            .unwrap();
        let last_from_offset = from_offset.states().into_iter().last().unwrap();

        assert_approx_eq!(
            last_from_epoch.position(),
            last_from_offset.position(),
            rtol <= 1e-6
        );
        assert_approx_eq!(
            last_from_epoch.velocity(),
            last_from_offset.velocity(),
            rtol <= 1e-6
        );

        // The offset trajectory must begin at the offset time, not at the epoch.
        assert_eq!(from_offset.start_time(), offset_start);
    }

    /// `state_at` and `propagate` agree on the final state.
    #[test]
    fn test_state_at_matches_propagate() {
        let orbit = initial_state();
        let epoch = orbit.time();
        let propagator = J2Propagator::new(orbit);
        let target = epoch + TimeDelta::from_minutes(40);

        let direct = propagator.state_at(target).unwrap();

        let trajectory = propagator.propagate(Interval::new(epoch, target)).unwrap();
        let last = trajectory.states().into_iter().last().unwrap();

        assert_approx_eq!(direct.position(), last.position(), rtol <= 1e-6);
        assert_approx_eq!(direct.velocity(), last.velocity(), rtol <= 1e-6);
    }

    /// `propagate_to` yields one state per requested time, consistent with `state_at`.
    #[test]
    fn test_propagate_to() {
        let orbit = initial_state();
        let epoch = orbit.time();
        let propagator = J2Propagator::new(orbit);

        let interval = Interval::new(epoch, epoch + TimeDelta::from_minutes(40));
        let times: Vec<_> = interval.step_by(TimeDelta::from_minutes(10)).collect();

        let trajectory = propagator.propagate_to(times.clone()).unwrap();
        let states = trajectory.states();

        // Exactly one output state per requested time.
        assert_eq!(states.len(), times.len());

        // The first requested time is the epoch, so the initial state comes back.
        assert_approx_eq!(states[0].position(), orbit.position(), rtol <= 1e-10);

        // The final state matches a direct `state_at` call at the same time.
        let final_time = *times.last().unwrap();
        let expected = propagator.state_at(final_time).unwrap();
        assert_approx_eq!(
            states.last().unwrap().position(),
            expected.position(),
            rtol <= 1e-6
        );
    }

    /// `propagate_to` with times starting after the epoch agrees with `state_at` at
    /// both ends.
    #[test]
    fn test_propagate_to_with_offset_times() {
        let orbit = initial_state();
        let epoch = orbit.time();
        let propagator = J2Propagator::new(orbit);

        let first = epoch + TimeDelta::from_minutes(20);
        let last = first + TimeDelta::from_minutes(20);
        let times: Vec<_> = Interval::new(first, last)
            .step_by(TimeDelta::from_minutes(5))
            .collect();

        let trajectory = propagator.propagate_to(times.clone()).unwrap();
        let states = trajectory.states();

        assert_eq!(states.len(), times.len());

        // Leading state agrees with `state_at` at the first offset time.
        let expected_first = propagator.state_at(times[0]).unwrap();
        assert_approx_eq!(
            states[0].position(),
            expected_first.position(),
            rtol <= 1e-6
        );

        // Trailing state agrees with `state_at` at the last requested time.
        let expected_last = propagator.state_at(*times.last().unwrap()).unwrap();
        assert_approx_eq!(
            states.last().unwrap().position(),
            expected_last.position(),
            rtol <= 1e-6
        );
    }

    /// Fewer than two requested times is rejected.
    #[test]
    fn test_propagate_to_too_few_times() {
        let orbit = initial_state();
        let propagator = J2Propagator::new(orbit);

        // No times at all.
        let empty = propagator.propagate_to(vec![]);
        assert!(empty.is_err());

        // A single time.
        let single = propagator.propagate_to(vec![orbit.time()]);
        assert!(single.is_err());
    }
}