from __future__ import annotations
import math
import numpy as np
from benchmarks.gpu_comparison.implementations import astrojax_kernels
from benchmarks.gpu_comparison.implementations.jax_utils import shard_across_devices
from benchmarks.gpu_comparison.tasks.base import BatchConfig, BatchTask
# Every implementation this task is benchmarked against, as
# (name, dtype, backend) triples turned into BatchConfig objects in order.
_ALL_CONFIGS = [
    BatchConfig(name=cfg_name, dtype=cfg_dtype, backend=cfg_backend)
    for cfg_name, cfg_dtype, cfg_backend in (
        ("brahe-rust-rayon", "f64", "rust"),
        ("astrojax-cpu", "f64", "astrojax-cpu"),
        ("astrojax-gpu", "f32", "astrojax-gpu"),
        ("astrojax-multigpu", "f32", "astrojax-multigpu"),
    )
]
# Batch-size ladder: a single element, then powers of ten from 100 to 1,000,000.
_LADDER = [1] + [10 ** exponent for exponent in range(2, 7)]
class GcrfToItrfStateTask(BatchTask):
    """Benchmark task: batched GCRF -> ITRF state-vector transformation.

    Each input row pairs a UTC epoch (given as an MJD) with a 6-element GCRF
    state ``[x, y, z, vx, vy, vz]`` in SI units (m, m/s).
    """

    name = "frames.gcrf_to_itrf"
    module = "frames"
    description = "Transform N (epoch, state6) pairs from GCRF to ITRF"
    configs = _ALL_CONFIGS

    def batch_sizes(self) -> list[int]:
        """Return the batch sizes to benchmark (the module-level ladder)."""
        return _LADDER

    def generate_inputs(self, batch_size: int, seed: int) -> dict:
        """Generate reproducible random inputs for one benchmark run.

        Epochs are drawn uniformly over 5 * 365.25 days starting at
        MJD 60310.0 (2024-01-01 00:00 UTC). States are circular, equatorial
        (z = 0, vz = 0) orbits whose radius is ``R_EARTH`` plus 400-1500 km,
        placed at a uniformly random true anomaly.

        Args:
            batch_size: Number of (epoch, state) pairs to generate.
            seed: Seed for ``numpy.random.default_rng``; equal seeds give
                identical inputs.

        Returns:
            ``{"mjd_utc": list[float], "state_gcrf": list[list[float]]}``
            with ``batch_size`` entries each; every state has 6 components.
        """
        rng = np.random.default_rng(seed)
        R_EARTH = 6378137.0  # m
        GM_EARTH = 3.986004418e14  # m^3 / s^2
        base_mjd = 60310.0  # 2024-01-01 00:00 UTC
        # The draw order (epochs, radii, anomalies) fixes the output per seed.
        mjds = base_mjd + rng.uniform(0.0, 5 * 365.25, batch_size)
        a = R_EARTH + rng.uniform(400e3, 1500e3, batch_size)
        v = np.sqrt(GM_EARTH / a)  # circular orbital speed
        nu = rng.uniform(0.0, 2 * np.pi, batch_size)
        cos_nu = np.cos(nu)
        sin_nu = np.sin(nu)
        states = np.empty((batch_size, 6), dtype=np.float64)
        states[:, 0] = a * cos_nu
        states[:, 1] = a * sin_nu
        states[:, 2] = 0.0
        # Velocity is perpendicular to the radius vector (circular orbit).
        states[:, 3] = -v * sin_nu
        states[:, 4] = v * cos_nu
        states[:, 5] = 0.0
        return {"mjd_utc": mjds.tolist(), "state_gcrf": states.tolist()}
def _jnp_dtype(dtype: str):
import jax.numpy as jnp
return jnp.float32 if dtype == "f32" else jnp.float64
def _build_batched_epoch_from_mjd(mjd_utc_list, dtype_str: str):
    """Build a batched astrojax ``Epoch`` from a sequence of UTC MJDs.

    Each MJD is converted to a Julian Date (``MJD + 2400000.5``). That JD is
    split into an integer day (``floor``) and the seconds elapsed since that
    day boundary. The split is done on the host in float64 *before* casting
    to the working dtype. Near JD 2.46e6 a float32 can only resolve 0.25 day,
    so splitting in float32 would quantize the time of day to 6-hour steps.
    The seconds part lies in [0, 86400), where float32 still resolves
    about 8 ms.

    Args:
        mjd_utc_list: Sequence of Modified Julian Dates (UTC).
        dtype_str: ``"f64"`` for float64 seconds; anything else gives float32.

    Returns:
        The result of ``Epoch._from_internal(jd, seconds, kahan_c)``, with one
        entry per input MJD. ``jd`` is int32, ``seconds`` uses the requested
        dtype, and ``kahan_c`` is zeros.
    """
    import jax.numpy as jnp
    from astrojax import Epoch

    jdtype = jnp.float64 if dtype_str == "f64" else jnp.float32
    # Host-side float64 split; only the small-magnitude seconds get cast.
    jd_full = np.asarray(mjd_utc_list, dtype=np.float64) + 2400000.5
    jd_day = np.floor(jd_full)
    seconds_of_day = (jd_full - jd_day) * 86400.0
    _jd = jnp.asarray(jd_day.astype(np.int32))
    _seconds = jnp.asarray(seconds_of_day, dtype=jdtype)
    # No accumulated compensation error for freshly built epochs.
    _kahan_c = jnp.zeros_like(_seconds)
    return Epoch._from_internal(_jd, _seconds, _kahan_c)
def _build_gcrf_to_itrf(task, batch_size, dtype, seed, devices):
    """Build the astrojax runner for the ``frames.gcrf_to_itrf`` task.

    The builder loads the EOP table and generates the task inputs. It then
    converts those inputs to JAX arrays in the requested dtype and prepares a
    batched ``state_gcrf_to_itrf``. That function is ``vmap``-ed over epochs
    and states, with the EOP data shared across the batch.

    Args:
        task: Task object whose ``generate_inputs(batch_size, seed)`` returns
            a dict with ``"mjd_utc"`` and ``"state_gcrf"`` entries.
        batch_size: Number of (epoch, state) pairs.
        dtype: ``"f32"`` or ``"f64"``.
        seed: RNG seed forwarded to ``task.generate_inputs``.
        devices: Sequence of target devices. Zero or one device uses ``jit``.
            If that device lacks ``device_kind``, or none is given, JAX's
            default placement is used. Several devices use ``pmap`` across
            exactly those devices.

    Returns:
        ``(run, {})``. ``run`` takes one ignored argument and executes the
        transform on the pre-placed inputs. In the multi-device case the
        output keeps any padding rows.
    """
    import jax
    import jax.numpy as jnp
    from astrojax.frames import state_gcrf_to_itrf
    from benchmarks.gpu_comparison.config import BRAHE_EOP_FILE
    from benchmarks.gpu_comparison.data_alignment import load_eop_for_astrojax

    eop = load_eop_for_astrojax(BRAHE_EOP_FILE)
    params = task.generate_inputs(batch_size, seed)
    states = jnp.array(params["state_gcrf"], dtype=_jnp_dtype(dtype))
    batched_epoch = _build_batched_epoch_from_mjd(params["mjd_utc"], dtype)

    # EOP data is shared by every row; epochs and states map over axis 0.
    fn = jax.vmap(state_gcrf_to_itrf, in_axes=(None, 0, 0))

    if len(devices) <= 1:
        device = devices[0] if devices else None
        if hasattr(device, "device_kind"):
            # Commit every input (EOP included) to the target device up front,
            # so the timed call does no host->device transfers. jit follows
            # committed inputs, so the deprecated ``jax.jit(device=...)`` is
            # not needed.
            eop, batched_epoch, states = jax.device_put(
                (eop, batched_epoch, states), device
            )
        # Without a JAX device object, fall back to JAX's default placement.
        # Still run the transform rather than returning the raw inputs.
        compiled = jax.jit(fn)
        return (lambda _: compiled(eop, batched_epoch, states)), {}

    n_dev = len(devices)
    batch = states.shape[0]
    # Round the batch up to a multiple of the device count.
    padded = ((batch + n_dev - 1) // n_dev) * n_dev
    n_pad = padded - batch
    if n_pad:
        # Repeat the last real row instead of zero-filling. A zero epoch (JD 0)
        # is presumably outside the EOP table's coverage; edge padding keeps
        # the padded lanes on realistic inputs.
        def _pad_leaf(leaf):
            widths = [(0, n_pad)] + [(0, 0)] * (leaf.ndim - 1)
            return jnp.pad(leaf, widths, mode="edge")

        states = _pad_leaf(states)
        batched_epoch = jax.tree_util.tree_map(_pad_leaf, batched_epoch)

    # Leading axis = device, second axis = per-device batch.
    states_sharded = shard_across_devices(states.reshape(n_dev, -1, 6), devices)
    epoch_sharded = jax.tree_util.tree_map(
        lambda leaf: shard_across_devices(leaf.reshape(n_dev, -1), devices),
        batched_epoch,
    )
    # Map over the same devices the shards were placed on. Without
    # ``devices=``, pmap would use the first n_dev local devices.
    pmapped = jax.pmap(fn, in_axes=(None, 0, 0), devices=devices)
    return (lambda _: pmapped(eop, epoch_sharded, states_sharded)), {}
# Register _build_gcrf_to_itrf under the same key as GcrfToItrfStateTask.name;
# presumably the benchmark harness looks builders up by task name (verify in
# astrojax_kernels).
astrojax_kernels.register("frames.gcrf_to_itrf", _build_gcrf_to_itrf)