from typing import Tuple
import astrora._core as core
import numpy as np
def assert_allclose_with_context(
    actual: np.ndarray,
    desired: np.ndarray,
    rtol: float = 1e-7,
    atol: float = 0,
    context: str = "",
    err_msg: str = "",
) -> None:
    """Assert ``actual`` and ``desired`` are element-wise close, with a labelled report.

    Wraps :func:`numpy.testing.assert_allclose`. When ``context`` is empty the
    original ``AssertionError`` propagates unchanged. Otherwise a new
    ``AssertionError`` (chained to the original) is raised. Its message starts
    with ``context`` and summarises the maximum absolute and relative
    differences. If the inputs cannot be broadcast together, or broadcast to
    an empty array, it reports their shapes instead.

    Args:
        actual: Values produced by the code under test (array-like).
        desired: Reference values (array-like).
        rtol: Relative tolerance forwarded to ``assert_allclose``.
        atol: Absolute tolerance forwarded to ``assert_allclose``.
        context: Label prefixed to the failure message (e.g. ``"position"``).
        err_msg: Extra message forwarded to ``assert_allclose``.

    Raises:
        AssertionError: If the values are not close within the tolerances.
    """
    try:
        np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg)
    except AssertionError as e:
        if not context:
            raise
        # asarray so plain lists/tuples can be subtracted element-wise.
        actual_arr = np.asarray(actual)
        desired_arr = np.asarray(desired)
        try:
            abs_diff = np.abs(actual_arr - desired_arr)
        except ValueError:
            # Shapes don't broadcast; numpy already reported the mismatch, so
            # don't mask that AssertionError with a ValueError from here.
            abs_diff = None
        if abs_diff is None or abs_diff.size == 0:
            raise AssertionError(
                f"{context} comparison failed:\n"
                f" Shapes: actual={actual_arr.shape}, desired={desired_arr.shape}\n"
                f" Tolerances: rtol={rtol:.3e}, atol={atol:.3e}\n"
                f" Original error: {e}"
            ) from e
        max_diff = np.max(abs_diff)
        # Use |desired| in the denominator so a tiny negative reference value
        # cannot cancel the epsilon and produce a division by zero.
        rel_diff = np.max(abs_diff / (np.abs(desired_arr) + 1e-100))
        raise AssertionError(
            f"{context} comparison failed:\n"
            f" Max absolute difference: {max_diff:.3e}\n"
            f" Max relative difference: {rel_diff:.3e}\n"
            f" Tolerances: rtol={rtol:.3e}, atol={atol:.3e}\n"
            f" Original error: {e}"
        ) from e
def assert_states_equal(
    state1: "core.CartesianState",
    state2: "core.CartesianState",
    position_tol: float = 1e-6,
    velocity_tol: float = 1e-9,
    context: str = "",
) -> None:
    """Assert two Cartesian states have matching position and velocity vectors.

    Position and velocity are checked separately with
    ``assert_allclose_with_context``. ``position_tol`` / ``velocity_tol`` are
    passed as the absolute tolerance; that helper's default relative tolerance
    still applies.

    Args:
        state1: State under test.
        state2: Reference state.
        position_tol: Absolute tolerance on each position component.
        velocity_tol: Absolute tolerance on each velocity component.
        context: Optional label prefixed to the failure messages.

    Raises:
        AssertionError: If either vector differs beyond tolerance.
    """
    label_prefix = f"{context} " if context else ""
    # Convert all four vectors before comparing anything.
    vectors = {
        attr: (np.array(getattr(state1, attr)), np.array(getattr(state2, attr)))
        for attr in ("position", "velocity")
    }
    tolerances = {"position": position_tol, "velocity": velocity_tol}
    for attr, (lhs, rhs) in vectors.items():
        assert_allclose_with_context(
            lhs, rhs, atol=tolerances[attr], context=label_prefix + attr
        )
def assert_elements_equal(
    elem1: "core.OrbitalElements",
    elem2: "core.OrbitalElements",
    atol_m: float = 1.0,
    atol_rad: float = 1e-9,
    context: str = "",
    atol_e: float = 1e-9,
) -> None:
    """Assert two sets of classical orbital elements match.

    Each group is compared with ``assert_allclose_with_context``, whose
    default relative tolerance also applies:

    - semi-major axis, with absolute tolerance ``atol_m``;
    - eccentricity, with ``atol_e``;
    - the angles (i, RAAN, argument of periapsis, true anomaly), with ``atol_rad``.

    Whole turns of 2*pi between the two angle sets are removed first, so
    equivalent angles such as nu = 0 and nu = 2*pi compare equal.

    Args:
        elem1: Elements under test.
        elem2: Reference elements.
        atol_m: Absolute tolerance on the semi-major axis.
        atol_rad: Absolute tolerance on each angle (radians).
        context: Optional label prefixed to the failure messages.
        atol_e: Absolute tolerance on eccentricity (default matches the
            previously hard-coded value).

    Raises:
        AssertionError: If any element differs beyond tolerance.
    """
    prefix = f"{context} " if context else ""
    assert_allclose_with_context(
        np.array([elem1.a]),
        np.array([elem2.a]),
        atol=atol_m,
        context=f"{prefix}semi-major axis",
    )
    assert_allclose_with_context(
        np.array([elem1.e]),
        np.array([elem2.e]),
        atol=atol_e,
        context=f"{prefix}eccentricity",
    )
    angles1 = np.array([elem1.i, elem1.raan, elem1.argp, elem1.nu])
    angles2 = np.array([elem2.i, elem2.raan, elem2.argp, elem2.nu])
    # Remove whole 2*pi turns between the two sets. With no wrap, ``turns`` is
    # 0 and angles1 is compared unchanged, so the tolerance semantics are as
    # before. Non-finite differences get 0 turns, so NaN/inf behave as before.
    two_pi = 2 * np.pi
    with np.errstate(invalid="ignore"):
        turns = np.round((angles1 - angles2) / two_pi)
    turns = np.where(np.isfinite(turns), turns, 0.0)
    assert_allclose_with_context(
        angles1 - two_pi * turns, angles2, atol=atol_rad, context=f"{prefix}angles"
    )
def compute_specific_energy(state: "core.CartesianState", gm: float) -> float:
    """Return the specific orbital energy ``v**2 / 2 - gm / r`` of *state*.

    Units follow the inputs (J/kg for SI position, velocity and ``gm``).
    """
    position = np.array(state.position)
    velocity = np.array(state.velocity)
    kinetic = 0.5 * np.dot(velocity, velocity)
    potential = -gm / np.linalg.norm(position)
    return kinetic + potential
def compute_specific_angular_momentum(state: "core.CartesianState") -> np.ndarray:
    """Return the specific angular momentum vector ``h = r x v`` of *state*."""
    position, velocity = np.array(state.position), np.array(state.velocity)
    return np.cross(position, velocity)
def assert_energy_conserved(
    state_initial: "core.CartesianState",
    state_final: "core.CartesianState",
    gm: float,
    rtol: float = 1e-10,
    context: str = "",
) -> None:
    """Assert the specific orbital energy is unchanged between two states.

    The error measure is ``|E_final - E_initial| / |E_initial|``. When
    ``E_initial`` is exactly zero (parabolic trajectory), the absolute
    difference is used instead to avoid 0/0. A NaN energy on either side
    always fails.

    Args:
        state_initial: State before propagation.
        state_final: State after propagation.
        gm: Gravitational parameter used for both energies.
        rtol: Maximum allowed error measure.
        context: Optional label prefixed to the failure message.

    Raises:
        AssertionError: If the error measure exceeds ``rtol`` or is NaN.
    """
    e_initial = compute_specific_energy(state_initial, gm)
    e_final = compute_specific_energy(state_final, gm)
    delta = e_final - e_initial
    scale = abs(e_initial)
    rel_error = abs(delta) / scale if scale > 0 else abs(delta)
    # ``not <=`` (rather than ``>``) so NaN, e.g. from a diverged propagation,
    # fails instead of slipping through the comparison.
    if not rel_error <= rtol:
        raise AssertionError(
            f"{context + ': ' if context else ''}Energy not conserved!\n"
            f" Initial energy: {e_initial:.12e} J/kg\n"
            f" Final energy: {e_final:.12e} J/kg\n"
            f" Difference: {delta:.12e} J/kg\n"
            f" Relative error: {rel_error:.12e}\n"
            f" Tolerance: {rtol:.12e}"
        )
def assert_angular_momentum_conserved(
    state_initial: "core.CartesianState",
    state_final: "core.CartesianState",
    rtol: float = 1e-10,
    context: str = "",
) -> None:
    """Assert the magnitude of the specific angular momentum is unchanged.

    Only ``|h| = |r x v|`` is compared; a change in the direction of ``h`` is
    not detected. The error measure is ``| |h_f| - |h_i| | / |h_i|``. When
    ``|h_i|`` is zero (radial trajectory), the absolute difference is used
    instead to avoid 0/0. NaN on either side always fails.

    Args:
        state_initial: State before propagation.
        state_final: State after propagation.
        rtol: Maximum allowed error measure.
        context: Optional label prefixed to the failure message.

    Raises:
        AssertionError: If the error measure exceeds ``rtol`` or is NaN.
    """
    h_mag_initial = np.linalg.norm(compute_specific_angular_momentum(state_initial))
    h_mag_final = np.linalg.norm(compute_specific_angular_momentum(state_final))
    delta = h_mag_final - h_mag_initial
    rel_error = abs(delta) / h_mag_initial if h_mag_initial > 0 else abs(delta)
    # ``not <=`` (rather than ``>``) so NaN fails instead of passing silently.
    if not rel_error <= rtol:
        raise AssertionError(
            f"{context + ': ' if context else ''}Angular momentum not conserved!\n"
            f" Initial |h|: {h_mag_initial:.12e} m²/s\n"
            f" Final |h|: {h_mag_final:.12e} m²/s\n"
            f" Difference: {delta:.12e} m²/s\n"
            f" Relative error: {rel_error:.12e}\n"
            f" Tolerance: {rtol:.12e}"
        )
def generate_test_orbits(n_orbits: int = 10, seed: int = 42) -> list:
    """Generate ``n_orbits`` pseudo-random Earth orbits as ``core.OrbitalElements``.

    Draws are uniform, in this order per orbit:

    - a in [6.6e6, 1e8) m
    - e in [0, 0.95)
    - i in [0, pi)
    - RAAN, argument of periapsis and true anomaly in [0, 2*pi)

    A private ``np.random.RandomState(seed)`` is used instead of
    ``np.random.seed``. It produces exactly the same sequence the seeded
    global legacy generator would, but it no longer resets global RNG state
    that other tests may rely on.

    Args:
        n_orbits: Number of orbits to generate (0 yields an empty list).
        seed: Seed for the private random generator.

    Returns:
        List of ``core.OrbitalElements`` using ``core.constants.GM_EARTH``.
    """
    rng = np.random.RandomState(seed)
    two_pi = 2 * np.pi
    orbits = []
    for _ in range(n_orbits):
        a = rng.uniform(6.6e6, 1.0e8)
        e = rng.uniform(0.0, 0.95)
        i = rng.uniform(0.0, np.pi)
        raan = rng.uniform(0.0, two_pi)
        omega = rng.uniform(0.0, two_pi)
        nu = rng.uniform(0.0, two_pi)
        orbits.append(
            core.OrbitalElements(
                semi_major_axis=a,
                eccentricity=e,
                inclination=i,
                raan=raan,
                argument_of_periapsis=omega,
                true_anomaly=nu,
                gm=core.constants.GM_EARTH,
            )
        )
    return orbits
def generate_anomaly_test_cases(
    eccentricity: float, n_points: int = 100
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build matched (true, eccentric, mean) anomaly arrays for one eccentricity.

    True anomalies are ``n_points`` evenly spaced samples on [0, 2*pi], with
    both endpoints included. The eccentric anomalies come from
    ``core.true_to_eccentric_anomaly``. The mean anomalies come from applying
    ``core.eccentric_to_mean_anomaly`` to those eccentric anomalies.

    Returns:
        Tuple ``(true_anomalies, eccentric_anomalies, mean_anomalies)``.
    """
    nu_grid = np.linspace(0, 2 * np.pi, n_points)
    ecc_grid = np.array(
        [core.true_to_eccentric_anomaly(theta, eccentricity) for theta in nu_grid]
    )
    mean_grid = np.array(
        [core.eccentric_to_mean_anomaly(big_e, eccentricity) for big_e in ecc_grid]
    )
    return nu_grid, ecc_grid, mean_grid
def compute_position_difference_magnitude(
    state1: "core.CartesianState", state2: "core.CartesianState"
) -> float:
    """Return the Euclidean distance between the positions of two states."""
    delta = np.array(state1.position) - np.array(state2.position)
    return np.linalg.norm(delta)
def compute_velocity_difference_magnitude(
    state1: "core.CartesianState", state2: "core.CartesianState"
) -> float:
    """Return the Euclidean norm of the velocity difference between two states."""
    delta = np.array(state1.velocity) - np.array(state2.velocity)
    return np.linalg.norm(delta)
def classify_orbit_regime(semi_major_axis: float, earth_radius: float = 6378137.0) -> str:
    """Classify an orbit by altitude ``semi_major_axis - earth_radius`` (metres).

    The altitude bands are:

    - below 2,000 km: "LEO"
    - below 35,000 km: "MEO"
    - up to and including 37,000 km: "GEO"
    - anything else, including NaN: "HEO"
    """
    altitude = semi_major_axis - earth_radius
    for upper_bound, regime in ((2000000, "LEO"), (35000000, "MEO")):
        if altitude < upper_bound:
            return regime
    # Reaching here means altitude >= 35,000 km (or NaN, which falls to HEO).
    return "GEO" if altitude <= 37000000 else "HEO"
def is_circular_orbit(eccentricity: float, tol: float = 1e-6) -> bool:
    """Return True when ``eccentricity`` is strictly below ``tol``."""
    return tol > eccentricity
def is_equatorial_orbit(inclination: float, tol: float = 1e-6) -> bool:
    """Return True when ``inclination`` (radians) is within ``tol`` of 0 or pi."""
    near_prograde = abs(inclination) < tol
    return near_prograde or abs(np.pi - inclination) < tol
def is_polar_orbit(inclination: float, tol: float = 1e-6) -> bool:
    """Return True when ``inclination`` (radians) is within ``tol`` of pi/2."""
    offset_from_polar = np.pi / 2 - inclination
    return abs(offset_from_polar) < tol
def format_state_vector(state: "core.CartesianState", name: str = "State") -> str:
    """Render a Cartesian state as a four-line, fixed-width string.

    The first three components of position (m) and velocity (m/s) are shown,
    followed by their full Euclidean norms, all in ``14.6f`` format.
    """
    position = np.array(state.position)
    velocity = np.array(state.velocity)
    r_mag = np.linalg.norm(position)
    v_mag = np.linalg.norm(velocity)

    def xyz(vec):
        return ", ".join(f"{vec[k]:14.6f}" for k in range(3))

    return "\n".join(
        [
            f"{name}:",
            f" Position: [{xyz(position)}] m",
            f" Velocity: [{xyz(velocity)}] m/s",
            f" |r| = {r_mag:14.6f} m, |v| = {v_mag:14.6f} m/s",
        ]
    )
def format_orbital_elements(elem: "core.OrbitalElements", name: str = "Elements") -> str:
    """Render orbital elements as a seven-line, fixed-width string.

    ``a`` is shown in metres and ``e`` with 12 decimals. The four angles
    (i, RAAN, argument of periapsis, true anomaly) are converted from radians
    to degrees with ``np.degrees``.
    """
    header_rows = [
        f"{name}:",
        f" a = {elem.a:14.6f} m",
        f" e = {elem.e:14.12f}",
    ]
    angle_rows = [
        f" {symbol} = {np.degrees(value):14.6f}°"
        for symbol, value in (
            ("i", elem.i),
            ("Ω", elem.raan),
            ("ω", elem.argp),
            ("ν", elem.nu),
        )
    ]
    return "\n".join(header_rows + angle_rows)
def skip_if_no_ephemerides():
    """Return a pytest marker that skips a test when ``jplephem.spk`` is unavailable.

    If ``from jplephem.spk import SPK`` succeeds, a no-op ``skipif(False)``
    marker is returned. Otherwise a ``skip`` marker is returned.

    NOTE(review): this only checks that the ``jplephem`` package imports; it
    does not check that any ephemeris (SPK) data file is present — confirm
    whether that is the intended meaning of "ephemerides available".
    """
    import pytest

    try:
        from jplephem.spk import SPK  # noqa: F401  (import is the availability probe)
    except (ImportError, FileNotFoundError):
        return pytest.mark.skip(reason="JPL ephemerides not available")
    else:
        return pytest.mark.skipif(False, reason="")
def skip_if_slow(reason: str = "Test is too slow for regular runs"):
    """Return the ``pytest.mark.slow`` marker, carrying ``reason`` as a kwarg.

    The marker does not skip anything by itself. Whether marked tests are
    skipped depends on how the suite is run (e.g. ``-m "not slow"``) or on
    project hooks not visible here. ``reason`` was previously ignored. It is
    now attached to the marker, where hooks and reports can read it from
    ``mark.kwargs["reason"]``. The marker name is unchanged, so ``-m``
    selection behaves as before.
    """
    import pytest

    return pytest.mark.slow(reason=reason)