from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
from matplotlib.axes import Axes
from matplotlib.patches import Circle
# The project-internal types are optional: when the relative imports cannot be
# resolved (for example when this module is loaded outside its package), every
# one of the names is bound to None so the annotations below still resolve.
try:
    from .._core import Epoch
    from ..bodies import Body
    from ..maneuver import Maneuver
    from ..twobody import Orbit
except ImportError:
    # All-or-nothing fallback: a single failed import resets all four names.
    Orbit = None
    Body = None
    Maneuver = None
    Epoch = None
class StaticOrbitPlotter:
    """Draw orbits, trajectories and their attractor on a Matplotlib axes.

    Everything is drawn in the x/y plane in kilometres. Values that carry a
    ``unit`` attribute are converted with ``.to(u.km)``; bare numbers are
    assumed to be metres and divided by 1000.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new 8x8 inch figure is created.
    plane : optional
        Stored on the instance but not used by any drawing code in this class.
    dark : bool, optional
        Use a dark axes background with white spines, ticks and axis labels.
    """

    _DARK_BACKGROUND = "#1a1a1a"

    def __init__(self, ax: "Optional[Axes]" = None, plane=None, dark: bool = False):
        owns_figure = ax is None
        if owns_figure:
            _, ax = plt.subplots(figsize=(8, 8))
        self.ax = ax
        self._attractor: "Optional[Body]" = None
        self._attractor_patch = None
        self._plane = plane
        self._frame_set = False
        self._dark = dark
        self.ax.set_aspect("equal")
        self.ax.grid(True, alpha=0.3)
        if dark:
            self.ax.set_facecolor(self._DARK_BACKGROUND)
            for side in ("bottom", "top", "left", "right"):
                self.ax.spines[side].set_color("white")
            self.ax.tick_params(colors="white")
            self.ax.xaxis.label.set_color("white")
            self.ax.yaxis.label.set_color("white")
            if owns_figure:
                # White tick labels would vanish on the default figure
                # background. A caller-supplied figure is left untouched.
                self.ax.figure.set_facecolor(self._DARK_BACKGROUND)

    @property
    def attractor(self) -> "Optional[Body]":
        """The body set via `set_attractor`, or None if none has been set."""
        return self._attractor

    @staticmethod
    def _to_km(value):
        """Return ``value`` as a plain float array (or 0-d array) in km.

        Objects with a ``unit`` attribute are converted via ``.to(u.km)``;
        anything else is treated as metres.
        """
        if hasattr(value, "unit"):
            return np.asarray(value.to(u.km).value, dtype=float)
        return np.asarray(value, dtype=float) / 1000.0

    @staticmethod
    def _trail_bounds(n_points: int, points_per_segment: int = 10) -> "List[Tuple[int, int]]":
        """Split ``n_points`` samples into contiguous ``(start, stop)`` slices.

        Neighbouring slices share one sample so the drawn polyline has no
        gaps, the last slice always ends at ``n_points``, and any non-empty
        input yields at least one slice.
        """
        if n_points <= 0:
            return []
        segments = max(1, n_points // points_per_segment)
        edges = [(i * (n_points - 1)) // segments for i in range(segments + 1)]
        return [(edges[i], edges[i + 1] + 1) for i in range(segments)]

    @staticmethod
    def _sample_times(orbit, num_points: int):
        """Return the times at which ``orbit`` is sampled for drawing."""
        if orbit.ecc < 1.0:
            # Closed orbit: one full revolution, end points included.
            return np.linspace(0, orbit.period, num_points)
        period = getattr(orbit, "period", None)
        if period is not None and period > 0:
            t_range = 3 * period
        else:
            # No usable period: derive a time scale from the semi-latus
            # rectum and the attractor's gravitational parameter.
            t_range = 2 * np.pi * np.sqrt(orbit.p**3 / orbit.attractor.mu)
        return np.linspace(-t_range, t_range, num_points)

    def _draw_path(self, x, y, *, label, color, trail) -> list:
        """Plot the polyline ``(x, y)`` and return the created line artists.

        With ``trail`` the path is drawn as several segments whose opacity
        rises from 0.2 to 1.0; only the first segment carries ``label`` so
        the legend shows a single entry.
        """
        style = {"label": label}
        if color is not None:
            style["color"] = color
        bounds = self._trail_bounds(len(x)) if trail else []
        if not bounds:
            (line,) = self.ax.plot(x, y, **style)
            return [line]
        # A single segment is drawn fully opaque rather than at the faint end.
        alphas = np.linspace(0.2, 1.0, len(bounds)) if len(bounds) > 1 else [1.0]
        lines = []
        for (start, stop), alpha in zip(bounds, alphas):
            (line,) = self.ax.plot(x[start:stop], y[start:stop], alpha=alpha, **style)
            lines.append(line)
            style.pop("label", None)
            # Pin the colour so later segments do not advance the colour cycle.
            style.setdefault("color", line.get_color())
        return lines

    def _draw_marker(self, x, y, color):
        """Plot a single round position marker and return its artist."""
        (marker,) = self.ax.plot(
            x,
            y,
            marker="o",
            markersize=8,
            color=color,
            markeredgecolor="white" if not self._dark else "black",
            markeredgewidth=1.5,
            zorder=20,
        )
        return marker

    def _label_axes(self) -> None:
        """Set the default axis labels unless labels are already present."""
        if not self.ax.get_xlabel():
            self.ax.set_xlabel("x (km)")
        if not self.ax.get_ylabel():
            self.ax.set_ylabel("y (km)")

    def _finalize(self) -> None:
        """Add a legend when labelled artists exist and tighten the layout."""
        _, labels = self.ax.get_legend_handles_labels()
        if labels:
            self.ax.legend(loc="upper right")
        # Operate on this plotter's figure, not on pyplot's current figure.
        self.ax.figure.tight_layout()

    def set_attractor(self, attractor: "Body") -> None:
        """Store ``attractor`` and draw it as a filled circle at the origin.

        The circle is only drawn when the attractor has an ``R`` attribute.
        A circle left by an earlier call is removed first.
        """
        self._attractor = attractor
        previous = getattr(self, "_attractor_patch", None)
        if previous is not None:
            previous.remove()
            self._attractor_patch = None
        if hasattr(attractor, "R"):
            circle = Circle(
                (0, 0),
                float(self._to_km(attractor.R)),
                color="#3d59ab" if not self._dark else "#4d69bb",
                label=getattr(attractor, "name", None),
                zorder=10,
            )
            self.ax.add_patch(circle)
            self._attractor_patch = circle

    def plot(
        self,
        orbit: "Orbit",
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
        num_points: int = 150,
    ) -> "Tuple[List, plt.Line2D]":
        """Draw ``orbit`` and a marker at its current position ``orbit.r``.

        Parameters
        ----------
        orbit
            Object providing ``ecc``, ``period``, ``p``, ``attractor``, ``r``
            and ``sample(times)`` returning a ``(positions, _)`` pair.
        label, color
            Legend label and line colour; the colour cycle is used when
            ``color`` is None.
        trail
            Draw the path as segments of increasing opacity.
        num_points
            Number of sample times.

        Returns
        -------
        tuple
            ``(lines, position)``: the list of path artists and the marker.
        """
        if self._attractor is None:
            self.set_attractor(orbit.attractor)
        positions, _ = orbit.sample(self._sample_times(orbit, num_points))
        x = self._to_km(positions[:, 0])
        y = self._to_km(positions[:, 1])
        lines = self._draw_path(x, y, label=label, color=color, trail=trail)
        marker_color = color if color is not None else lines[0].get_color()
        position = self._draw_marker(
            float(self._to_km(orbit.r[0])),
            float(self._to_km(orbit.r[1])),
            marker_color,
        )
        self._label_axes()
        return lines, position

    def plot_body_orbit(
        self,
        body: "Body",
        epoch: Optional["Epoch"] = None,
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
    ) -> "Tuple[List, plt.Line2D]":
        """Not implemented; always raises.

        Raises
        ------
        NotImplementedError
            Unconditionally.
        """
        raise NotImplementedError(
            "plot_body_orbit requires ephemeris data integration. "
            "Use plot() with an Orbit object created from ephemerides instead."
        )

    def plot_trajectory(
        self,
        coordinates: np.ndarray,
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
    ) -> "Tuple[List, plt.Line2D]":
        """Draw a sampled trajectory and a marker at its last point.

        Parameters
        ----------
        coordinates
            Two-dimensional array-like; column 0 is used as x, column 1 as y.
        label, color, trail
            Same meaning as in `plot`.

        Returns
        -------
        tuple
            ``(lines, position)``: the list of path artists and the marker.

        Raises
        ------
        ValueError
            If no attractor has been set or ``coordinates`` has no rows.
        """
        if self._attractor is None:
            raise ValueError("Must set attractor before plotting trajectory")
        x = self._to_km(coordinates[:, 0])
        y = self._to_km(coordinates[:, 1])
        if x.size == 0:
            raise ValueError("coordinates must contain at least one point")
        lines = self._draw_path(x, y, label=label, color=color, trail=trail)
        marker_color = color if color is not None else lines[0].get_color()
        position = self._draw_marker(x[-1], y[-1], marker_color)
        self._label_axes()
        return lines, position

    def plot_maneuver(
        self,
        initial_orbit: "Orbit",
        maneuver: "Maneuver",
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
    ) -> "Tuple[List, plt.Line2D]":
        """Plot ``initial_orbit`` under the maneuver's label.

        NOTE(review): ``maneuver`` is only checked for an ``impulses``
        attribute; the impulses themselves are not applied or drawn here.

        Raises
        ------
        ValueError
            If ``maneuver`` has no ``impulses`` attribute.
        """
        if not hasattr(maneuver, "impulses"):
            raise ValueError("Maneuver must have impulses attribute")
        return self.plot(initial_orbit, label=label or "Maneuver", color=color, trail=trail)

    def show(self) -> None:
        """Add the legend if needed, tighten the layout and call ``plt.show()``."""
        self._finalize()
        plt.show()

    def savefig(self, filename: str, **kwargs) -> None:
        """Add the legend if needed, tighten the layout and save the figure.

        Extra keyword arguments are forwarded to ``Figure.savefig``.
        """
        self._finalize()
        self.ax.figure.savefig(filename, **kwargs)