1use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
6
7use differential_equations::{
8 interpolate::Interpolation,
9 ode::{ODE, ODEProblem, OrdinaryNumericalMethod},
10 prelude::ExplicitRungeKutta,
11 traits::State,
12};
13use glam::DVec3;
14use lox_bodies::{
15 DynOrigin, J2, MeanRadius, Origin, PointMass, TryJ2, TryMeanRadius, TryPointMass,
16 UndefinedOriginPropertyError,
17};
18use lox_core::coords::Cartesian;
19use lox_frames::{DynFrame, ReferenceFrame};
20use lox_time::Time;
21use lox_time::deltas::TimeDelta;
22use lox_time::intervals::TimeInterval;
23use lox_time::time_scales::{DynTimeScale, TimeScale};
24use thiserror::Error;
25
26use crate::orbits::{CartesianOrbit, TrajectorError, Trajectory};
27use crate::propagators::Propagator;
28
/// Number of integration steps per characteristic timescale `|r| / |v|` of the initial
/// state; used by `default_h_max` to derive the default maximum step size.
const H_MAX_STEPS_PER_TIMESCALE: f64 = 8.0;
32
/// Errors that can occur while propagating with a [`J2Propagator`].
#[derive(Debug, Error)]
pub enum J2Error {
    /// The ODE solver reported a failure; the payload is its debug-formatted error.
    #[error("ODE solver failed: {0}")]
    Solver(String),
    /// The solver finished without yielding any output state.
    #[error("ODE solver returned no solution")]
    EmptySolution,
    /// `propagate_to` was given fewer than two times.
    #[error("at least two time steps are needed")]
    InvalidTimeSteps,
    /// A trajectory error, converted automatically via `From`.
    #[error(transparent)]
    Trajectory(#[from] TrajectorError),
}
49
50#[derive(Debug, Clone, Copy)]
52pub struct J2Propagator<T: TimeScale, O: TryJ2 + TryPointMass + TryMeanRadius, R: ReferenceFrame> {
53 initial_state: CartesianOrbit<T, O, R>,
54 rtol: f64,
55 atol: f64,
56 h_max: f64,
57 h_min: f64,
58 max_steps: usize,
59}
60
/// [`J2Propagator`] over the dynamic (`Dyn*`) time scale, origin and frame types.
pub type DynJ2Propagator = J2Propagator<DynTimeScale, DynOrigin, DynFrame>;
63
64fn default_h_max(position: DVec3, velocity: DVec3) -> f64 {
65 position.length() / velocity.length() / H_MAX_STEPS_PER_TIMESCALE
66}
67
68impl<T, O, R> J2Propagator<T, O, R>
70where
71 T: TimeScale,
72 O: J2 + PointMass + MeanRadius + Copy,
73 R: ReferenceFrame,
74{
75 pub fn new(initial_state: CartesianOrbit<T, O, R>) -> Self {
77 let h_max = default_h_max(initial_state.position(), initial_state.velocity());
78 Self {
79 initial_state,
80 rtol: 1e-8,
81 atol: 1e-6,
82 h_max,
83 h_min: 1e-6,
84 max_steps: 100_000,
85 }
86 }
87}
88
89impl<T, O, R> J2Propagator<T, O, R>
91where
92 T: TimeScale,
93 O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
94 R: ReferenceFrame,
95{
96 pub fn try_new(
98 initial_state: CartesianOrbit<T, O, R>,
99 ) -> Result<Self, UndefinedOriginPropertyError> {
100 initial_state.origin().try_gravitational_parameter()?;
101 initial_state.origin().try_j2()?;
102 initial_state.origin().try_mean_radius()?;
103
104 let h_max = default_h_max(initial_state.position(), initial_state.velocity());
105 Ok(Self {
106 initial_state,
107 rtol: 1e-8,
108 atol: 1e-6,
109 h_max,
110 h_min: 1e-6,
111 max_steps: 100_000,
112 })
113 }
114}
115
116impl<T, O, R> J2Propagator<T, O, R>
117where
118 T: TimeScale,
119 O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
120 R: ReferenceFrame,
121{
122 pub fn with_rtol(mut self, rtol: f64) -> Self {
124 self.rtol = rtol;
125 self
126 }
127
128 pub fn with_atol(mut self, atol: f64) -> Self {
130 self.atol = atol;
131 self
132 }
133
134 pub fn with_h_max(mut self, h_max: f64) -> Self {
136 self.h_max = h_max;
137 self
138 }
139
140 pub fn with_h_min(mut self, h_min: f64) -> Self {
142 self.h_min = h_min;
143 self
144 }
145
146 pub fn with_max_steps(mut self, max_steps: usize) -> Self {
148 self.max_steps = max_steps;
149 self
150 }
151
152 pub fn initial_state(&self) -> &CartesianOrbit<T, O, R> {
154 &self.initial_state
155 }
156
157 pub fn origin(&self) -> O {
159 self.initial_state.origin()
160 }
161
162 pub fn reference_frame(&self) -> R
164 where
165 R: Copy,
166 {
167 self.initial_state.reference_frame()
168 }
169
170 fn gravitational_parameter(&self) -> f64 {
171 self.initial_state
172 .origin()
173 .try_gravitational_parameter()
174 .expect("gravitational parameter should be available")
175 .as_f64()
176 }
177
178 fn j2(&self) -> f64 {
179 self.initial_state
180 .origin()
181 .try_j2()
182 .expect("J2 should be available")
183 }
184
185 fn mean_radius(&self) -> f64 {
186 self.initial_state
187 .origin()
188 .try_mean_radius()
189 .expect("mean radius should be available")
190 .as_f64()
191 }
192
193 fn solver(
194 &self,
195 ) -> impl OrdinaryNumericalMethod<f64, CartesianState> + Interpolation<f64, CartesianState>
196 {
197 ExplicitRungeKutta::dop853()
198 .rtol(self.rtol)
199 .atol(self.atol)
200 .h_max(self.h_max)
201 .h_min(self.h_min)
202 .max_steps(self.max_steps)
203 }
204}
205
206impl<T, O, R> ODE<f64, CartesianState> for J2Propagator<T, O, R>
207where
208 T: TimeScale,
209 O: TryJ2 + TryPointMass + TryMeanRadius + Copy,
210 R: ReferenceFrame,
211{
212 fn diff(&self, _t: f64, s: &CartesianState, dydt: &mut CartesianState) {
213 let mu = self.gravitational_parameter();
214 let j2 = self.j2();
215 let rm = self.mean_radius();
216
217 let p = s.position();
218 let pm = p.length();
219 let pj = -3.0 / 2.0 * mu * j2 * rm.powi(2) / pm.powi(5);
220
221 let acc = -mu * p / pm.powi(3)
222 + pj * p * (DVec3::new(1.0, 1.0, 3.0) - 5.0 * p.z.powi(2) / pm.powi(2));
223
224 dydt.0.set_position(s.velocity());
225 dydt.0.set_velocity(acc);
226 }
227}
228
229impl<T, O, R> Propagator<T, O> for J2Propagator<T, O, R>
230where
231 T: TimeScale + Copy + PartialOrd,
232 O: TryJ2 + TryPointMass + TryMeanRadius + Origin + Copy,
233 R: ReferenceFrame + Copy,
234{
235 type Frame = R;
236 type Error = J2Error;
237
238 fn state_at(&self, time: Time<T>) -> Result<CartesianOrbit<T, O, R>, J2Error> {
239 let epoch = self.initial_state.time();
240 let t0 = 0.0_f64;
241 let t1 = (time - epoch).to_seconds().to_f64();
242 let s0 = CartesianState::from(*self.initial_state());
243
244 let mut solver = self.solver();
245
246 let problem = ODEProblem::new(self, t0, t1, s0);
247 let solution = problem
248 .solve(&mut solver)
249 .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
250
251 let (_, final_state) = solution.iter().next_back().ok_or(J2Error::EmptySolution)?;
252
253 let origin = self.initial_state.origin();
254 let frame = self.initial_state.reference_frame();
255 Ok(CartesianOrbit::new(final_state.0, time, origin, frame))
256 }
257
258 fn propagate(&self, interval: TimeInterval<T>) -> Result<Trajectory<T, O, R>, J2Error> {
259 let start = interval.start();
260
261 let s0: CartesianState = if start != self.initial_state.time() {
263 self.state_at(start)?
264 } else {
265 *self.initial_state()
266 }
267 .into();
268
269 let t1 = (interval.end() - start).to_seconds().to_f64();
270
271 let mut solver = self.solver();
272
273 let problem = ODEProblem::new(self, 0.0, t1, s0);
274 let solution = problem
275 .solve(&mut solver)
276 .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
277
278 let origin = self.initial_state.origin();
279 let frame = self.initial_state.reference_frame();
280
281 Ok(solution
282 .iter()
283 .map(|(t, s)| {
284 CartesianOrbit::new(s.0, start + TimeDelta::from_seconds_f64(*t), origin, frame)
285 })
286 .collect())
287 }
288
289 fn propagate_to(
290 &self,
291 times: impl IntoIterator<Item = Time<T>>,
292 ) -> Result<Trajectory<T, O, Self::Frame>, Self::Error> {
293 let times: Vec<Time<T>> = times.into_iter().collect();
294 if times.len() < 2 {
295 return Err(J2Error::InvalidTimeSteps);
296 }
297
298 let t0 = times[0];
299 let steps: Vec<f64> = times
300 .iter()
301 .map(|t| (*t - t0).to_seconds().to_f64())
302 .collect();
303 let t1 = *steps.last().unwrap();
304
305 let s0: CartesianState = if t0 != self.initial_state.time() {
307 self.state_at(t0)?
308 } else {
309 *self.initial_state()
310 }
311 .into();
312
313 let mut solver = self.solver();
314
315 let problem = ODEProblem::new(self, 0.0, t1, s0);
316 let solution = problem
317 .t_eval(steps)
318 .solve(&mut solver)
319 .map_err(|e| J2Error::Solver(format!("{:?}", e)))?;
320
321 let origin = self.initial_state.origin();
322 let frame = self.initial_state.reference_frame();
323
324 Ok(solution
325 .iter()
326 .map(|(t, s)| {
327 CartesianOrbit::new(s.0, t0 + TimeDelta::from_seconds_f64(*t), origin, frame)
328 })
329 .collect())
330 }
331}
332
/// Six-component ODE state vector used by the J2 propagator: a newtype around
/// [`Cartesian`] for which this module implements the solver's `State` trait and the
/// arithmetic operators it needs.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CartesianState(Cartesian);
337
338impl CartesianState {
339 fn position(&self) -> DVec3 {
340 self.0.position()
341 }
342
343 fn velocity(&self) -> DVec3 {
344 self.0.velocity()
345 }
346}
347
348impl State<f64> for CartesianState {
349 fn len(&self) -> usize {
350 6
351 }
352
353 fn get(&self, i: usize) -> f64 {
354 match i {
355 0 => self.0.position().x,
356 1 => self.0.position().y,
357 2 => self.0.position().z,
358 3 => self.0.velocity().x,
359 4 => self.0.velocity().y,
360 5 => self.0.velocity().z,
361 _ => unreachable!("index out of bounds"),
362 }
363 }
364
365 fn set(&mut self, i: usize, value: f64) {
366 match i {
367 0 => {
368 self.0.set::<0>(value);
369 }
370 1 => {
371 self.0.set::<1>(value);
372 }
373 2 => {
374 self.0.set::<2>(value);
375 }
376 3 => {
377 self.0.set::<3>(value);
378 }
379 4 => {
380 self.0.set::<4>(value);
381 }
382 5 => {
383 self.0.set::<5>(value);
384 }
385 _ => unreachable!("index out of bounds"),
386 };
387 }
388
389 fn zeros() -> Self {
390 Self::default()
391 }
392}
393
394impl Add for CartesianState {
395 type Output = Self;
396
397 fn add(self, rhs: Self) -> Self::Output {
398 CartesianState(self.0 + rhs.0)
399 }
400}
401
402impl AddAssign for CartesianState {
403 fn add_assign(&mut self, rhs: Self) {
404 self.0 += rhs.0;
405 }
406}
407
408impl Sub for CartesianState {
409 type Output = Self;
410
411 fn sub(self, rhs: Self) -> Self::Output {
412 CartesianState(self.0 - rhs.0)
413 }
414}
415
416impl Mul<f64> for CartesianState {
417 type Output = Self;
418
419 fn mul(self, rhs: f64) -> Self::Output {
420 CartesianState(self.0 * rhs)
421 }
422}
423
424impl Div<f64> for CartesianState {
425 type Output = Self;
426
427 fn div(self, rhs: f64) -> Self::Output {
428 CartesianState(self.0 / rhs)
429 }
430}
431
432impl Neg for CartesianState {
433 type Output = Self;
434
435 fn neg(self) -> Self::Output {
436 CartesianState(-self.0)
437 }
438}
439
440impl<T, O, R> From<CartesianOrbit<T, O, R>> for CartesianState
441where
442 T: TimeScale,
443 O: Origin,
444 R: ReferenceFrame,
445{
446 fn from(orbit: CartesianOrbit<T, O, R>) -> Self {
447 Self(orbit.state())
448 }
449}
450
#[cfg(test)]
mod tests {
    use lox_bodies::Earth;
    use lox_frames::Icrf;
    use lox_test_utils::assert_approx_eq;
    use lox_time::Time;
    use lox_time::intervals::Interval;
    use lox_time::time;
    use lox_time::time_scales::Tdb;
    use lox_units::{DistanceUnits, VelocityUnits};

    use super::*;

    /// Shared fixture: an Earth-centred ICRF orbit at 2023-01-01 TDB.
    fn initial_state() -> CartesianOrbit<Tdb, Earth, Icrf> {
        let time = time!(Tdb, 2023, 1, 1).unwrap();
        CartesianOrbit::new(
            Cartesian::new(
                1131.340.km(),
                -2282.343.km(),
                6672.423.km(),
                -5.64305.kps(),
                4.30333.kps(),
                2.42879.kps(),
            ),
            time,
            Earth,
            Icrf,
        )
    }

    /// The ODE right-hand side at the fixture state: the position derivative must equal
    /// the velocity, and the acceleration must match a reference value.
    #[test]
    fn test_j2_ode() {
        let s0_orbit = initial_state();
        let j2 = J2Propagator::new(s0_orbit);

        // Same state vector as the fixture, wrapped directly as an ODE state.
        let s0 = CartesianState(Cartesian::new(
            1131.340.km(),
            -2282.343.km(),
            6672.423.km(),
            -5.64305.kps(),
            4.30333.kps(),
            2.42879.kps(),
        ));
        let mut dsdt = CartesianState::default();
        j2.diff(0.0, &s0, &mut dsdt);

        let acc_exp = DVec3::new(-1.2324031762444367, 2.4862258582559233, -7.287340551142344);

        assert_eq!(dsdt.position(), s0.velocity());
        assert_approx_eq!(dsdt.velocity(), acc_exp, rtol <= 1e-8);
    }

    /// Propagating the fixture for 40 minutes must end near a reference state.
    #[test]
    fn test_j2_propagator() {
        let s0_orbit = initial_state();
        let time = s0_orbit.time();
        let j2 = J2Propagator::new(s0_orbit);
        let dt = TimeDelta::from_minutes(40);
        let interval = Interval::new(time, time + dt);
        let traj = j2.propagate(interval).unwrap();
        let s1 = traj.states().into_iter().last().unwrap();
        let p_act = s1.position();
        let v_act = s1.velocity();
        let p_exp = DVec3::new(
            -4255.223590627231e3,
            4384.471704756651e3,
            -3.936_135_007_962_321e6,
        );
        let v_exp = DVec3::new(
            3.6559899898490054e3,
            -1.884445831960271e3,
            -6.123308149589636e3,
        );
        // NOTE(review): rtol <= 1e-1 is very loose; it only guards against gross errors.
        assert_approx_eq!(p_act, p_exp, rtol <= 1e-1);
        assert_approx_eq!(v_act, v_exp, rtol <= 1e-1);
    }

    /// An interval starting 20 minutes after the epoch must reach the same end state as
    /// one starting at the epoch, and its trajectory must begin at the offset start.
    #[test]
    fn test_propagate_with_offset_interval() {
        let s0_orbit = initial_state();
        let epoch = s0_orbit.time();
        let j2 = J2Propagator::new(s0_orbit);

        let dt = TimeDelta::from_minutes(40);
        let offset = TimeDelta::from_minutes(20);

        // Reference: propagate over the full span from the epoch.
        let full = Interval::new(epoch, epoch + dt);
        let traj_full = j2.propagate(full).unwrap();
        let s_full = traj_full.states().into_iter().last().unwrap();

        // Same end time, but the interval starts after the epoch.
        let offset_interval = Interval::new(epoch + offset, epoch + dt);
        let traj_offset = j2.propagate(offset_interval).unwrap();
        let s_offset = traj_offset.states().into_iter().last().unwrap();

        assert_approx_eq!(s_full.position(), s_offset.position(), rtol <= 1e-6);
        assert_approx_eq!(s_full.velocity(), s_offset.velocity(), rtol <= 1e-6);

        assert_eq!(traj_offset.start_time(), epoch + offset);
    }

    /// `state_at` and the last state of `propagate` over the same span must agree.
    #[test]
    fn test_state_at_matches_propagate() {
        let s0_orbit = initial_state();
        let epoch = s0_orbit.time();
        let j2 = J2Propagator::new(s0_orbit);

        let target = epoch + TimeDelta::from_minutes(40);
        let state = j2.state_at(target).unwrap();

        let interval = Interval::new(epoch, target);
        let traj = j2.propagate(interval).unwrap();
        let last = traj.states().into_iter().last().unwrap();

        assert_approx_eq!(state.position(), last.position(), rtol <= 1e-6);
        assert_approx_eq!(state.velocity(), last.velocity(), rtol <= 1e-6);
    }

    /// `propagate_to` yields one state per requested time; the first matches the initial
    /// position and the last matches `state_at` at the final time.
    #[test]
    fn test_propagate_to() {
        let s0_orbit = initial_state();
        let epoch = s0_orbit.time();
        let j2 = J2Propagator::new(s0_orbit);

        let dt = TimeDelta::from_minutes(40);
        let interval = Interval::new(epoch, epoch + dt);
        let times: Vec<_> = interval.step_by(TimeDelta::from_minutes(10)).collect();

        let traj = j2.propagate_to(times.clone()).unwrap();
        let states = traj.states();

        assert_eq!(states.len(), times.len());

        assert_approx_eq!(states[0].position(), s0_orbit.position(), rtol <= 1e-10);

        let last_time = *times.last().unwrap();
        let expected = j2.state_at(last_time).unwrap();
        assert_approx_eq!(
            states.last().unwrap().position(),
            expected.position(),
            rtol <= 1e-6
        );
    }

    /// When the requested times start after the epoch, the first and last states must
    /// match `state_at` at those times.
    #[test]
    fn test_propagate_to_with_offset_times() {
        let s0_orbit = initial_state();
        let epoch = s0_orbit.time();
        let j2 = J2Propagator::new(s0_orbit);

        let start = epoch + TimeDelta::from_minutes(20);
        let end = start + TimeDelta::from_minutes(20);
        let interval = Interval::new(start, end);
        let times: Vec<_> = interval.step_by(TimeDelta::from_minutes(5)).collect();

        let traj = j2.propagate_to(times.clone()).unwrap();
        let states = traj.states();

        assert_eq!(states.len(), times.len());

        let expected_first = j2.state_at(times[0]).unwrap();
        assert_approx_eq!(
            states[0].position(),
            expected_first.position(),
            rtol <= 1e-6
        );

        let expected_last = j2.state_at(*times.last().unwrap()).unwrap();
        assert_approx_eq!(
            states.last().unwrap().position(),
            expected_last.position(),
            rtol <= 1e-6
        );
    }

    /// Fewer than two requested times (none, or just one) must be rejected.
    #[test]
    fn test_propagate_to_too_few_times() {
        let s0_orbit = initial_state();
        let j2 = J2Propagator::new(s0_orbit);

        let result = j2.propagate_to(vec![]);
        assert!(result.is_err());

        let result = j2.propagate_to(vec![s0_orbit.time()]);
        assert!(result.is_err());
    }
}