from __future__ import annotations
import os
from pathlib import Path
from typing import Any
def _enable_x64() -> None:
os.environ.setdefault("JAX_ENABLE_X64", "1")
def load_eop_for_astrojax(path: Path) -> Any:
    """Load an EOP data file with ``astrojax.eop.load_eop_from_file``.

    Args:
        path: Location of the EOP file; it is handed to astrojax as ``str``.

    Returns:
        Whatever ``load_eop_from_file`` returns (opaque astrojax object).
    """
    # Request x64 before the deferred astrojax import below, so the env var
    # is in place if that import is what first pulls in jax.
    _enable_x64()
    from astrojax.eop import load_eop_from_file as _read_eop

    eop_file = str(path)
    return _read_eop(eop_file)
def load_space_weather_for_astrojax(path: Path) -> Any:
    """Load a space-weather data file with ``astrojax.space_weather.load_sw_from_file``.

    Args:
        path: Location of the space-weather file; passed to astrojax as ``str``.

    Returns:
        Whatever ``load_sw_from_file`` returns (opaque astrojax object).
    """
    # Request x64 before the deferred astrojax import below, so the env var
    # is in place if that import is what first pulls in jax.
    _enable_x64()
    from astrojax.space_weather import load_sw_from_file as _read_sw

    sw_file = str(path)
    return _read_sw(sw_file)
def install_global_providers() -> None:
    """Install the benchmark's EOP and space-weather data as astrojax globals.

    Loads ``BRAHE_EOP_FILE`` and ``BRAHE_SPACE_WEATHER_FILE`` from
    ``benchmarks.gpu_comparison.config`` and passes the loaded objects to
    ``astrojax.eop.set_global_eop`` and ``astrojax.space_weather.set_global_sw``.

    Each setter is optional: if the installed astrojax does not expose it,
    that provider is skipped and a warning is logged. Errors raised while
    *loading* the data files are not suppressed.
    """
    import logging

    logger = logging.getLogger(__name__)

    _enable_x64()
    from benchmarks.gpu_comparison.config import (
        BRAHE_EOP_FILE,
        BRAHE_SPACE_WEATHER_FILE,
    )

    # Guard only the setter import. A missing setter is a tolerated API gap,
    # but a failure while reading the data files should surface instead of
    # being silently ignored.
    try:
        from astrojax.eop import set_global_eop
    except (ImportError, AttributeError) as exc:
        logger.warning("astrojax global EOP provider not installed: %s", exc)
    else:
        set_global_eop(load_eop_for_astrojax(BRAHE_EOP_FILE))

    try:
        from astrojax.space_weather import set_global_sw
    except (ImportError, AttributeError) as exc:
        logger.warning(
            "astrojax global space-weather provider not installed: %s", exc
        )
    else:
        set_global_sw(load_space_weather_for_astrojax(BRAHE_SPACE_WEATHER_FILE))