from typing import Optional
import numpy as np
from astropy import units as u
try:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
HAS_PLOTLY = True
except ImportError:
HAS_PLOTLY = False
go = None
try:
from .._core import Epoch
from ..bodies import Body
from ..twobody import Orbit
except ImportError:
Orbit = None
Body = None
Epoch = None
class OrbitPlotter3D:
    """Interactive 3D orbit plotter built on Plotly.

    Everything is drawn in kilometres. Inputs that carry an astropy ``unit``
    attribute are converted with ``.to(u.km)``; plain numeric inputs are
    assumed to be metres and divided by 1000.
    """

    # Colour cycle shared by ``plot`` and ``plot_trajectory``; both advance the
    # same ``_orbit_count``, so they must index the same palette.
    _DEFAULT_COLORS = (
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    )

    # Surface colours for the attractor sphere, keyed by ``attractor.name``.
    _BODY_COLORS = {
        "Sun": "#FDB813",
        "Mercury": "#8C7853",
        "Venus": "#FFC649",
        "Earth": "#4d69bb",
        "Moon": "#999999",
        "Mars": "#cd5c5c",
        "Jupiter": "#c88b3a",
        "Saturn": "#fad5a5",
        "Uranus": "#4fd0e7",
        "Neptune": "#4166f5",
        "Pluto": "#ba8c6e",
    }

    # Colour used for attractors whose name is not in ``_BODY_COLORS``.
    _FALLBACK_BODY_COLOR = "#3d59ab"

    # Number of traces a fading trail is split into.
    _TRAIL_SEGMENTS = 10

    def __init__(self, dark: bool = False):
        """Create an empty Plotly figure with a km-labelled 3D scene.

        Parameters
        ----------
        dark : bool
            Use the ``plotly_dark`` template with dark panes and legend
            instead of ``plotly_white``.

        Raises
        ------
        ImportError
            If Plotly could not be imported.
        """
        if not HAS_PLOTLY:
            raise ImportError(
                "Plotly is required for 3D plotting. " "Install it with: pip install plotly"
            )
        self.fig = go.Figure()
        self._attractor: Optional[Body] = None
        self._dark = dark
        self._orbit_count = 0  # drives the default colour cycle
        template = "plotly_dark" if dark else "plotly_white"
        # One translucent pane colour shared by all three axes.
        pane_color = "rgba(50, 50, 50, 0.5)" if dark else "rgba(230, 230, 230, 0.5)"
        legend_bg = "rgba(0, 0, 0, 0.8)" if dark else "rgba(255, 255, 255, 0.8)"
        self.fig.update_layout(
            scene=dict(
                xaxis_title="x (km)",
                yaxis_title="y (km)",
                zaxis_title="z (km)",
                aspectmode="data",
                xaxis=dict(showbackground=True, backgroundcolor=pane_color),
                yaxis=dict(showbackground=True, backgroundcolor=pane_color),
                zaxis=dict(showbackground=True, backgroundcolor=pane_color),
            ),
            template=template,
            showlegend=True,
            legend=dict(
                x=0.02,
                y=0.98,
                xanchor="left",
                yanchor="top",
                bgcolor=legend_bg,
            ),
            margin=dict(l=0, r=0, t=30, b=0),
        )

    @property
    def attractor(self) -> "Optional[Body]":
        """The central body, or ``None`` if none has been set yet."""
        return self._attractor

    def set_attractor(self, attractor: "Body") -> None:
        """Set the central body and, if it exposes a radius, draw it.

        Parameters
        ----------
        attractor : Body
            Central body. When it has an ``R`` attribute, a sphere of that
            radius is added to the figure, coloured by ``attractor.name``.
        """
        self._attractor = attractor
        if not hasattr(attractor, "R"):
            return
        # R may be an astropy Quantity or a plain number (assumed metres).
        radius_km = float(self._to_km(attractor.R))
        azimuth = np.linspace(0, 2 * np.pi, 30)
        polar = np.linspace(0, np.pi, 20)
        x_sphere = radius_km * np.outer(np.cos(azimuth), np.sin(polar))
        y_sphere = radius_km * np.outer(np.sin(azimuth), np.sin(polar))
        z_sphere = radius_km * np.outer(np.ones(np.size(azimuth)), np.cos(polar))
        color = self._BODY_COLORS.get(attractor.name, self._FALLBACK_BODY_COLOR)
        self.fig.add_trace(
            go.Surface(
                x=x_sphere,
                y=y_sphere,
                z=z_sphere,
                colorscale=[[0, color], [1, color]],  # flat, single-colour sphere
                showscale=False,
                name=attractor.name,
                hoverinfo="name",
                lighting=dict(ambient=0.6, diffuse=0.5, specular=0.3),
                opacity=0.9,
            )
        )

    def plot(
        self,
        orbit: "Orbit",
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
        num_points: int = 150,
    ) -> None:
        """Sample *orbit* and add its path plus a current-position marker.

        The attractor is taken from ``orbit.attractor`` if none is set yet.
        For ``ecc < 1`` one full period is sampled from 0 to ``orbit.period``.
        Otherwise a symmetric window ``[-t_range, t_range]`` is sampled, where
        ``t_range`` is ``3 * orbit.period`` if that is positive, else
        ``2*pi*sqrt(p**3 / mu)``.

        Parameters
        ----------
        orbit : Orbit
            Orbit to draw; must provide ``ecc``, ``period``, ``p``, ``r``,
            ``attractor`` and ``sample(times)`` returning ``(positions, _)``.
        label : str, optional
            Legend / hover label.
        color : str, optional
            Line colour; defaults to the next colour in the shared cycle.
        trail : bool
            Draw the path as a fading trail instead of a uniform line.
        num_points : int
            Number of sample times passed to ``orbit.sample``.
        """
        if self._attractor is None:
            self.set_attractor(orbit.attractor)
        if orbit.ecc < 1.0:
            times = np.linspace(0, orbit.period, num_points)
        else:
            if hasattr(orbit, "period") and orbit.period > 0:
                t_range = 3 * orbit.period
            else:
                t_range = 2 * np.pi * np.sqrt(orbit.p**3 / orbit.attractor.mu)
            times = np.linspace(-t_range, t_range, num_points)
        positions, _ = orbit.sample(times)
        x, y, z = self._xyz_km(positions)
        line_color = self._next_color(color)
        hover_text = f'{label or "Orbit"}'
        self._add_path(
            x, y, z, label=label, line_color=line_color, trail=trail, hover_text=hover_text
        )
        r_km = self._to_km(orbit.r)
        self.fig.add_trace(
            go.Scatter3d(
                x=[r_km[0]],
                y=[r_km[1]],
                z=[r_km[2]],
                mode="markers",
                marker=dict(
                    size=8,
                    color=line_color,
                    line=dict(color=self._marker_outline(), width=2),
                ),
                name=f"{label} position" if label else "Current position",
                showlegend=False,
                hoverinfo="text",
                text=f'{label or "Orbit"} - Current position',
            )
        )

    def plot_trajectory(
        self,
        coordinates: np.ndarray,
        *,
        label: Optional[str] = None,
        color: Optional[str] = None,
        trail: bool = False,
    ) -> None:
        """Add a precomputed trajectory and a marker at its last point.

        Parameters
        ----------
        coordinates : array-like of shape (N, >=3)
            Positions; columns 0-2 are x, y, z. Quantities are converted to
            km, plain numbers are assumed to be metres.
        label : str, optional
            Legend label.
        color : str, optional
            Line colour; defaults to the next colour in the shared cycle.
        trail : bool
            Draw the path as a fading trail instead of a uniform line.

        Raises
        ------
        ValueError
            If no attractor has been set, or *coordinates* is empty.
        """
        if self._attractor is None:
            raise ValueError("Must set attractor before plotting trajectory")
        x, y, z = self._xyz_km(coordinates)
        if len(x) == 0:
            # Fail clearly instead of an IndexError at the final marker.
            raise ValueError("Cannot plot an empty trajectory")
        line_color = self._next_color(color)
        self._add_path(x, y, z, label=label, line_color=line_color, trail=trail)
        self.fig.add_trace(
            go.Scatter3d(
                x=[x[-1]],
                y=[y[-1]],
                z=[z[-1]],
                mode="markers",
                marker=dict(
                    size=8,
                    color=line_color,
                    line=dict(color=self._marker_outline(), width=2),
                ),
                name=f"{label} final" if label else "Final position",
                showlegend=False,
            )
        )

    def show(self) -> None:
        """Display the figure using Plotly's default renderer."""
        self.fig.show()

    def savefig(self, filename: str, **kwargs) -> None:
        """Save the figure; ``.html``/``.htm`` as HTML, anything else as an image.

        Extra keyword arguments are forwarded to ``write_html`` or
        ``write_image``.

        Raises
        ------
        ImportError
            If ``write_image`` raises ImportError (message suggests kaleido).
        """
        import os

        ext = os.path.splitext(filename)[1]
        if ext.lower() in (".html", ".htm"):
            self.fig.write_html(filename, **kwargs)
            return
        try:
            self.fig.write_image(filename, **kwargs)
        except ImportError as err:
            raise ImportError(
                f"Saving to {ext} format requires kaleido. "
                "Install it with: pip install kaleido"
            ) from err

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _to_km(values):
        """Return *values* as floats in kilometres.

        Objects with a ``unit`` attribute are converted via ``.to(u.km)``;
        anything else is treated as metres and divided by 1000.
        """
        if hasattr(values, "unit"):
            return values.to(u.km).value
        return np.asarray(values, dtype=float) / 1000.0

    @classmethod
    def _xyz_km(cls, coordinates):
        """Split an (N, >=3) coordinate array into x, y, z columns in km."""
        km = cls._to_km(coordinates)
        return km[:, 0], km[:, 1], km[:, 2]

    @staticmethod
    def _segment_bounds(n_points, segments):
        """Return *segments* ``(start, stop)`` slice bounds covering ``n_points``.

        Consecutive segments share one endpoint so the drawn trail has no
        gaps, and the last segment always stops at ``n_points`` so no
        trailing samples are dropped.
        """
        if n_points <= 0:
            return [(0, 0)] * segments
        edges = [k * (n_points - 1) // segments for k in range(segments + 1)]
        return [(edges[k], edges[k + 1] + 1) for k in range(segments)]

    def _next_color(self, color):
        """Return *color* or the next default colour; always advance the counter."""
        if color is None:
            color = self._DEFAULT_COLORS[self._orbit_count % len(self._DEFAULT_COLORS)]
        self._orbit_count += 1
        return color

    def _marker_outline(self):
        """Outline colour for position markers, contrasting with the theme."""
        return "black" if self._dark else "white"

    def _add_path(self, x, y, z, *, label, line_color, trail, hover_text=None):
        """Add a line trace, or a fading trail of overlapping line segments.

        When *hover_text* is given, every trace uses it as hover text.
        """
        hover = {} if hover_text is None else {"hoverinfo": "text", "text": hover_text}
        if not trail:
            self.fig.add_trace(
                go.Scatter3d(
                    x=x,
                    y=y,
                    z=z,
                    mode="lines",
                    line=dict(color=line_color, width=3),
                    name=label,
                    **hover,
                )
            )
            return
        segments = self._TRAIL_SEGMENTS
        for i, (start, stop) in enumerate(self._segment_bounds(len(x), segments)):
            is_last = i == segments - 1
            self.fig.add_trace(
                go.Scatter3d(
                    x=x[start:stop],
                    y=y[start:stop],
                    z=z[start:stop],
                    mode="lines",
                    line=dict(color=line_color, width=3),
                    # Older segments fade out; opacity rises with i.
                    opacity=0.2 + 0.8 * (i / segments),
                    # Only the newest segment carries the legend entry.
                    name=label if is_last else None,
                    showlegend=(is_last and label is not None),
                    **hover,
                )
            )