import time
import numpy as np
from astrora._core import (
batch_mean_to_eccentric_anomaly,
batch_mean_to_true_anomaly,
batch_propagate_states,
constants,
)
# Earth's gravitational parameter (GM), re-exported from the compiled core's
# constants so the benchmarks below share a single authoritative value.
GM_EARTH = constants.GM_EARTH
class TestParallelBatchPerformance:
    """Performance benchmarks for the parallel (Rayon-backed) batch kernels.

    Each test hands the target callable to the pytest-benchmark ``benchmark``
    fixture and then sanity-checks the shape or length of the returned array.
    """

    def _run_propagation_benchmark(self, benchmark, n_orbits):
        # Shared driver for the propagation benchmarks: a one-hour step over
        # n_orbits synthetic states, with a shape check on the result.
        states = self._create_test_orbits(n_orbits)
        result = benchmark(batch_propagate_states, states, np.array([3600.0]), GM_EARTH)
        assert result.shape == (n_orbits, 6)

    def test_batch_propagation_small(self, benchmark):
        self._run_propagation_benchmark(benchmark, 10)

    def test_batch_propagation_medium(self, benchmark):
        self._run_propagation_benchmark(benchmark, 100)

    def test_batch_propagation_large(self, benchmark):
        self._run_propagation_benchmark(benchmark, 1000)

    def test_batch_propagation_very_large(self, benchmark):
        self._run_propagation_benchmark(benchmark, 5000)

    def test_batch_anomaly_conversion_small(self, benchmark):
        angles = np.linspace(0, 2 * np.pi, 100)
        result = benchmark(batch_mean_to_eccentric_anomaly, angles, np.array([0.5]))
        assert len(result) == 100

    def test_batch_anomaly_conversion_large(self, benchmark):
        angles = np.linspace(0, 2 * np.pi, 10000)
        result = benchmark(batch_mean_to_eccentric_anomaly, angles, np.array([0.5]))
        assert len(result) == 10000

    def test_batch_mean_to_true_large(self, benchmark):
        angles = np.linspace(0, 2 * np.pi, 10000)
        result = benchmark(batch_mean_to_true_anomaly, angles, np.array([0.6]))
        assert len(result) == 10000

    @staticmethod
    def _create_test_orbits(n: int) -> np.ndarray:
        """Build ``n`` circular-orbit state vectors ``[x, y, z, vx, vy, vz]``.

        Orbit radii sweep from 7,000 km to 12,000 km and the inclination
        sweeps from 0 toward 90 degrees across the batch.
        """
        radii = np.linspace(7000e3, 12000e3, n)
        speeds = np.sqrt(GM_EARTH / radii)  # circular-orbit speed at each radius
        tilts = np.radians(90.0 * np.arange(n) / n)
        states = np.zeros((n, 6))
        states[:, 0] = radii  # position along x; y and z stay zero
        states[:, 4] = speeds * np.cos(tilts)  # in-plane velocity component
        states[:, 5] = speeds * np.sin(tilts)  # out-of-plane velocity component
        return states
def manual_timing_comparison():
    """Benchmark ``batch_propagate_states`` across batch sizes and print a
    throughput table.

    For each batch size: builds synthetic circular equatorial orbits, does one
    warm-up call, then averages wall-clock time over several runs (fewer runs
    for large batches to keep total runtime bounded).
    """
    print("\n" + "=" * 80)
    print("Rayon Parallelization Performance Demonstration")
    print("=" * 80)

    sizes = [10, 100, 500, 1000, 2000, 5000]

    print("\nBatch Propagation Performance:")
    print("-" * 80)
    print(f"{'Batch Size':<12} {'Time (ms)':<15} {'Throughput (orbits/sec)':<25}")
    print("-" * 80)

    for n in sizes:
        # Circular equatorial orbits: position along x, velocity along y.
        altitudes = np.linspace(7000e3, 12000e3, n)
        states = np.zeros((n, 6))
        for i, a in enumerate(altitudes):
            v = np.sqrt(GM_EARTH / a)
            states[i, 0] = a
            states[i, 4] = v

        dt = np.array([3600.0])

        # Warm-up call so one-time startup cost (thread pool, caches) does
        # not skew the timed runs below.
        _ = batch_propagate_states(states, dt, GM_EARTH)

        # Fewer repetitions for the big batches to bound total runtime.
        n_runs = 10 if n < 1000 else 5
        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            _ = batch_propagate_states(states, dt, GM_EARTH)
            end = time.perf_counter()
            times.append(end - start)

        # BUG FIX: these two assignments were fused onto one line in the
        # original, which is a SyntaxError.
        avg_time = np.mean(times) * 1000  # milliseconds
        throughput = n / (avg_time / 1000)  # orbits per second

        print(f"{n:<12} {avg_time:>12.3f} ms {throughput:>20.1f} orbits/sec")

    print("\n" + "=" * 80)
    print("Note: Performance scales with number of CPU cores available.")
    print("Expected speedup: 2-8x on multi-core systems for large batches.")
    print("=" * 80 + "\n")
# Script entry point: print the manual timing table when run directly
# (not executed under pytest, which only collects the Test* class above).
if __name__ == "__main__":
    manual_timing_comparison()