from typing import List, Optional, Tuple, Union
import numpy as np
try:
import matplotlib.animation as animation
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.patches import Circle
HAS_MATPLOTLIB = True
except ImportError:
HAS_MATPLOTLIB = False
animation = None
try:
import plotly.graph_objects as go
HAS_PLOTLY = True
except ImportError:
HAS_PLOTLY = False
go = None
try:
from .._core import Duration, Epoch
from ..bodies import Body
from ..twobody import Orbit
except ImportError:
Orbit = None
Body = None
Epoch = None
Duration = None
def animate_orbit(
    orbit: Union["Orbit", List["Orbit"]],
    duration: Optional[float] = None,
    num_frames: int = 100,
    fps: int = 20,
    trail: bool = True,
    ax: Optional["Axes"] = None,
    figsize: Tuple[float, float] = (8, 8),
    dark: bool = False,
    show_time: bool = True,
    save_to: Optional[str] = None,
    **kwargs,
) -> "animation.FuncAnimation":
    """Animate one or more orbits projected onto the x-y plane with Matplotlib.

    Every orbit is sampled at ``num_frames`` evenly spaced times in
    ``[0, duration]`` via ``orb.sample(times)``. The sampled positions are
    divided by 1000 and plotted on axes labelled in km (so ``sample`` presumably
    returns metres -- TODO confirm). The first orbit's attractor is drawn as a
    filled circle at the origin.

    Args:
        orbit: A single orbit or a list of orbits to animate together.
        duration: Time span to animate. The on-screen clock divides it by 3600
            and labels the result "hours", so seconds are assumed. Defaults to
            the first orbit's ``period`` (its ``.value`` if it has one).
        num_frames: Number of sampled instants / animation frames (>= 1).
        fps: Playback rate in frames per second (> 0). It sets the frame
            interval and the frame rate used when saving.
        trail: Draw the path traced so far behind each marker.
        ax: Existing axes to draw into. When omitted, a new figure of size
            ``figsize`` is created.
        figsize: Figure size, used only when ``ax`` is None.
        dark: Use a dark background with white text and spines.
        show_time: Show the elapsed time in the upper-left corner.
        save_to: Optional output path. ``.gif`` uses the ``pillow`` writer,
            ``.mp4`` uses ``ffmpeg`` and ``.html`` uses ``html``; the extension
            match is case-insensitive. Any other extension uses Matplotlib's
            default writer.
        **kwargs: Forwarded to ``FuncAnimation``. ``interval`` is dropped
            because it is derived from ``fps``.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` driving the plot.

    Raises:
        ImportError: If Matplotlib is not installed.
        ValueError: If no orbits are given, ``num_frames < 1`` or ``fps <= 0``.
    """
    if not HAS_MATPLOTLIB:
        raise ImportError(
            "Matplotlib is required for 2D animations. " "Install it with: pip install matplotlib"
        )

    orbits = orbit if isinstance(orbit, list) else [orbit]
    if not orbits:
        raise ValueError("At least one orbit is required.")
    if num_frames < 1:
        raise ValueError(f"num_frames must be at least 1, got {num_frames}.")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}.")

    def _magnitude(quantity):
        """Return ``quantity.value`` for quantity-like objects, else ``quantity`` itself."""
        return quantity.value if hasattr(quantity, "value") else quantity

    if duration is None:
        duration = _magnitude(orbits[0].period)
    times = np.linspace(0, duration, num_frames)

    # One position array per orbit, indexed as data[frame, axis], divided by
    # 1000 to match the "km" axis labels.
    orbit_data = []
    for orb in orbits:
        positions, _velocities = orb.sample(times)
        orbit_data.append(_magnitude(positions) / 1000)

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
    ax.set_xlabel("x (km)", fontsize=12)
    ax.set_ylabel("y (km)", fontsize=12)
    ax.set_title("Orbit Animation", fontsize=14, fontweight="bold")

    if dark:
        fig.patch.set_facecolor("#1a1a1a")
        ax.set_facecolor("#1a1a1a")
        for side in ("bottom", "top", "left", "right"):
            ax.spines[side].set_color("white")
        ax.tick_params(colors="white")
        ax.xaxis.label.set_color("white")
        ax.yaxis.label.set_color("white")
        ax.title.set_color("white")
        text_color = "white"
    else:
        text_color = "black"

    # The central body comes from the first orbit. Without an ``R`` attribute
    # it falls back to 6371 km (presumably Earth's mean radius).
    attractor = orbits[0].attractor
    body_radius = _magnitude(attractor.R) / 1000 if hasattr(attractor, "R") else 6371
    body_circle = Circle((0, 0), body_radius, color="#4169E1", zorder=10, label=attractor.name)
    ax.add_patch(body_circle)

    # Square limits, symmetric about the origin, with a 10% margin.
    max_r = max(np.max(np.abs(data)) for data in orbit_data)
    margin = max_r * 0.1
    ax.set_xlim(-max_r - margin, max_r + margin)
    ax.set_ylim(-max_r - margin, max_r + margin)

    orbit_artists = []
    colors = plt.cm.tab10(np.linspace(0, 1, len(orbits)))
    for i, (color, data) in enumerate(zip(colors, orbit_data)):
        trail_line = None
        if trail:
            (trail_line,) = ax.plot([], [], "-", color=color, alpha=0.6, linewidth=1.5, zorder=5)
        (position_marker,) = ax.plot(
            [],
            [],
            "o",
            color=color,
            markersize=8,
            zorder=20,
            label=f"Orbit {i+1}" if len(orbits) > 1 else "Satellite",
        )
        orbit_artists.append({"trail": trail_line, "marker": position_marker, "data": data})

    time_text = None
    if show_time:
        time_text = ax.text(
            0.02,
            0.98,
            "",
            transform=ax.transAxes,
            fontsize=12,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="wheat" if not dark else "gray", alpha=0.5),
            color=text_color,
        )

    ax.legend(loc="upper right")

    def init():
        """Clear every animated artist and return them (needed because blit=True)."""
        artists = []
        for art in orbit_artists:
            if art["trail"] is not None:
                art["trail"].set_data([], [])
                artists.append(art["trail"])
            art["marker"].set_data([], [])
            artists.append(art["marker"])
        if time_text is not None:
            time_text.set_text("")
            artists.append(time_text)
        return artists

    def animate_frame(frame):
        """Move each marker to sample ``frame`` and extend its trail up to it."""
        artists = []
        for art in orbit_artists:
            data = art["data"]
            if art["trail"] is not None:
                art["trail"].set_data(data[: frame + 1, 0], data[: frame + 1, 1])
                artists.append(art["trail"])
            # set_data expects sequences, hence the single-element lists.
            art["marker"].set_data([data[frame, 0]], [data[frame, 1]])
            artists.append(art["marker"])
        if time_text is not None:
            hours = times[frame] / 3600
            time_text.set_text(f"Time: {hours:.2f} hours")
            artists.append(time_text)
        return artists

    # ``interval`` is derived from ``fps``; passing it twice would be a TypeError.
    filtered_kwargs = {k: v for k, v in kwargs.items() if k != "interval"}
    anim = animation.FuncAnimation(
        fig,
        animate_frame,
        init_func=init,
        frames=num_frames,
        interval=1000 / fps,
        blit=True,
        **filtered_kwargs,
    )

    if save_to is not None:
        print(f"Saving animation to {save_to}...")
        lowered = save_to.lower()
        writer = next(
            (
                name
                for ext, name in ((".gif", "pillow"), (".mp4", "ffmpeg"), (".html", "html"))
                if lowered.endswith(ext)
            ),
            None,
        )
        if writer is None:
            anim.save(save_to, fps=fps)
        else:
            anim.save(save_to, writer=writer, fps=fps)
        print("Animation saved successfully!")

    return anim
def animate_orbit_3d(
    orbit: Union["Orbit", List["Orbit"]],
    duration: Optional[float] = None,
    num_frames: int = 100,
    fps: int = 20,
    trail: bool = True,
    dark: bool = False,
    show_time: bool = True,
    save_to: Optional[str] = None,
    include_plotlyjs: Union[bool, str] = "cdn",
    **kwargs,
) -> "go.Figure":
    """Build an interactive 3D Plotly animation of one or more orbits.

    Each orbit is sampled at ``num_frames`` evenly spaced times in
    ``[0, duration]`` via ``orb.sample(times)``. Positions are divided by 1000
    and shown on axes labelled in km (so ``sample`` presumably returns metres --
    TODO confirm). The first orbit's attractor is drawn as a translucent sphere
    at the origin. The figure gets Play/Pause buttons and a frame slider.

    Args:
        orbit: A single orbit or a list of orbits to animate together.
        duration: Time span to animate. Time labels divide it by 3600 and call
            the result "hours", so seconds are assumed. Defaults to the first
            orbit's ``period`` (its ``.value`` if it has one).
        num_frames: Number of sampled instants / animation frames (>= 1).
        fps: Playback rate used by the Play button, in frames per second (> 0).
        trail: Draw each orbit's path traced so far (same behaviour as
            ``animate_orbit``).
        dark: Use the ``plotly_dark`` template instead of ``plotly_white``.
        show_time: Show the elapsed time in the title (updated every frame)
            and in the slider labels. When False, slider labels are the frame
            indices.
        save_to: Optional HTML output path written with ``fig.write_html``.
        include_plotlyjs: Passed to ``write_html``. ``"cdn"`` needs internet
            access to view; ``True`` embeds plotly.js.
        **kwargs: Accepted for signature compatibility; currently unused.

    Returns:
        The animated ``plotly.graph_objects.Figure``.

    Raises:
        ImportError: If Plotly is not installed.
        ValueError: If no orbits are given, ``num_frames < 1`` or ``fps <= 0``.
    """
    if not HAS_PLOTLY:
        raise ImportError(
            "Plotly is required for 3D animations. " "Install it with: pip install plotly"
        )

    orbits = orbit if isinstance(orbit, list) else [orbit]
    if not orbits:
        raise ValueError("At least one orbit is required.")
    if num_frames < 1:
        raise ValueError(f"num_frames must be at least 1, got {num_frames}.")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}.")

    def _magnitude(quantity):
        """Return ``quantity.value`` for quantity-like objects, else ``quantity`` itself."""
        return quantity.value if hasattr(quantity, "value") else quantity

    def _title(seconds):
        """Title text showing the elapsed time converted to hours."""
        return f"Orbit Animation (Time: {seconds / 3600:.2f} hours)"

    if duration is None:
        duration = _magnitude(orbits[0].period)
    times = np.linspace(0, duration, num_frames)

    # One position array per orbit, indexed as data[frame, axis] with
    # axes 0/1/2 = x/y/z, divided by 1000 to match the "km" axis labels.
    orbit_data = []
    for orb in orbits:
        positions, _velocities = orb.sample(times)
        orbit_data.append(_magnitude(positions) / 1000)

    # The sphere geometry is the same in every frame, so compute it once.
    # Without an ``R`` attribute the radius falls back to 6371 km
    # (presumably Earth's mean radius).
    attractor = orbits[0].attractor
    body_radius = _magnitude(attractor.R) / 1000 if hasattr(attractor, "R") else 6371
    u = np.linspace(0, 2 * np.pi, 30)
    v = np.linspace(0, np.pi, 20)
    x_sphere = body_radius * np.outer(np.cos(u), np.sin(v))
    y_sphere = body_radius * np.outer(np.sin(u), np.sin(v))
    z_sphere = body_radius * np.outer(np.ones(np.size(u)), np.cos(v))

    colors = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#FFA07A", "#98D8C8", "#F7DC6F"]
    frames = []
    for frame_idx in range(num_frames):
        frame_data = [
            go.Surface(
                x=x_sphere,
                y=y_sphere,
                z=z_sphere,
                colorscale="Blues",
                showscale=False,
                opacity=0.7,
                name=attractor.name,
            )
        ]
        for orbit_idx, data in enumerate(orbit_data):
            color = colors[orbit_idx % len(colors)]
            if trail:
                # Path up to and including the current sample, as in animate_orbit.
                frame_data.append(
                    go.Scatter3d(
                        x=data[: frame_idx + 1, 0],
                        y=data[: frame_idx + 1, 1],
                        z=data[: frame_idx + 1, 2],
                        mode="lines",
                        line=dict(color=color, width=2),
                        opacity=0.5,
                        name=f"Orbit {orbit_idx + 1} trail",
                    )
                )
            frame_data.append(
                go.Scatter3d(
                    x=[data[frame_idx, 0]],
                    y=[data[frame_idx, 1]],
                    z=[data[frame_idx, 2]],
                    mode="markers",
                    marker=dict(size=8, color=color),
                    name=f"Satellite {orbit_idx + 1}",
                )
            )
        frame_kwargs = {"data": frame_data, "name": str(frame_idx)}
        if show_time:
            # Without a per-frame layout the title would stay frozen at t = 0.
            frame_kwargs["layout"] = {"title": {"text": _title(times[frame_idx])}}
        frames.append(go.Frame(**frame_kwargs))

    fig = go.Figure(data=frames[0].data, frames=frames)

    template = "plotly_dark" if dark else "plotly_white"
    max_r = max(np.max(np.abs(data)) for data in orbit_data)
    margin = max_r * 0.1
    axis_range = [-max_r - margin, max_r + margin]
    title_text = _title(times[0]) if show_time else "Orbit Animation"

    fig.update_layout(
        template=template,
        title=title_text,
        scene=dict(
            xaxis=dict(title="x (km)", range=axis_range),
            yaxis=dict(title="y (km)", range=axis_range),
            zaxis=dict(title="z (km)", range=axis_range),
            aspectmode="cube",
        ),
        updatemenus=[
            {
                "type": "buttons",
                "showactive": False,
                "buttons": [
                    {
                        "label": "Play",
                        "method": "animate",
                        "args": [
                            None,
                            {
                                "frame": {"duration": 1000 / fps, "redraw": True},
                                "fromcurrent": True,
                                "mode": "immediate",
                                "transition": {"duration": 0},
                            },
                        ],
                    },
                    {
                        "label": "Pause",
                        "method": "animate",
                        "args": [
                            [None],
                            {
                                "frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0},
                            },
                        ],
                    },
                ],
                "x": 0.1,
                "y": 0,
            }
        ],
        sliders=[
            {
                "active": 0,
                "steps": [
                    {
                        "args": [
                            [f.name],
                            {
                                "frame": {"duration": 0, "redraw": True},
                                "mode": "immediate",
                                "transition": {"duration": 0},
                            },
                        ],
                        "label": f"{times[i] / 3600:.2f}h" if show_time else str(i),
                        "method": "animate",
                    }
                    for i, f in enumerate(frames)
                ],
                "x": 0.1,
                "len": 0.9,
                "xanchor": "left",
                "y": 0,
                "yanchor": "top",
            }
        ],
    )

    if save_to is not None:
        print(f"Saving animation to {save_to}...")
        fig.write_html(save_to, include_plotlyjs=include_plotlyjs)
        if include_plotlyjs == "cdn":
            print("Animation saved successfully! (Using CDN - requires internet to view)")
        elif include_plotlyjs is True:
            print("Animation saved successfully! (Standalone - works offline)")
        else:
            print("Animation saved successfully!")

    return fig
# Names exported by ``from <this module> import *``.
__all__ = ["animate_orbit", "animate_orbit_3d"]