import os
import matplotlib.pyplot as plt
import pandas as pd
def _load_timings(csv_path):
    """Read the results CSV and add derived timing columns.

    Returns a ``(df, strategies)`` pair. ``strategies`` lists the distinct
    ``Strategy`` values in order of first appearance in the file, so legend
    order does not depend on the sort below.
    """
    df = pd.read_csv(csv_path)
    strategies = df["Strategy"].unique()
    # The per-strategy cumulative sum (and every plotted line) must follow
    # batch order; sort explicitly instead of trusting the CSV row order.
    # mergesort is stable, so rows with equal Batch keep their file order.
    df = df.sort_values(["Strategy", "Batch"], kind="mergesort").reset_index(drop=True)
    df["StepTimeMs"] = df["Time_Fit_ms"] + df["Time_Prune_ms"]
    df["CumulativeTimeMs"] = df.groupby("Strategy")["StepTimeMs"].cumsum()
    df["CumulativeTimeS"] = df["CumulativeTimeMs"] / 1000.0
    return df, strategies


def _plot_by_strategy(ax, df, strategies, column, title, ylabel, *, log_y=False):
    """Draw one marker line per strategy of ``df[column]`` against ``Batch`` on ``ax``."""
    for strategy in strategies:
        subset = df[df["Strategy"] == strategy]
        ax.plot(subset["Batch"], subset[column], marker="o", label=strategy)
    ax.set_title(title)
    ax.set_xlabel("Batch Index")
    ax.set_ylabel(ylabel)
    if log_y:
        ax.set_yscale("log")
    ax.legend()


def plot_benchmark(csv_path="results.csv", output_path="benchmark_results.png"):
    """Plot per-strategy benchmark metrics from a CSV file into a 2x2 figure.

    The CSV must provide the columns ``Strategy``, ``Batch``, ``Time_Fit_ms``,
    ``Time_Prune_ms``, ``MSE`` and ``Nodes``. The four panels show cumulative
    fit+prune time (seconds), per-batch fit+prune time (ms), test MSE on a log
    scale, and the model node count, with one line per strategy.

    Args:
        csv_path: Path of the results CSV (default ``"results.csv"``).
        output_path: Path the PNG figure is written to
            (default ``"benchmark_results.png"``).

    Returns:
        None. If ``csv_path`` is not an existing file, a message is printed
        and nothing is plotted.
    """
    if not os.path.isfile(csv_path):
        print(f"{csv_path} not found.")
        return

    df, strategies = _load_timings(csv_path)

    # A style context avoids permanently changing global rcParams for the
    # rest of the process (plt.style.use would).
    with plt.style.context("ggplot"):
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        try:
            axes = axes.flatten()
            _plot_by_strategy(
                axes[0], df, strategies, "CumulativeTimeS",
                "Cumulative Training + Pruning Time", "Time (seconds)",
            )
            _plot_by_strategy(
                axes[1], df, strategies, "StepTimeMs",
                "Step Time (Fit + Prune)", "Time (ms)",
            )
            _plot_by_strategy(
                axes[2], df, strategies, "MSE",
                "Test MSE (Cumulative Data)", "Mean Squared Error", log_y=True,
            )
            _plot_by_strategy(
                axes[3], df, strategies, "Nodes",
                "Total Model Nodes", "Count",
            )
            fig.tight_layout()
            fig.savefig(output_path)
        finally:
            # Release the figure even if plotting or saving fails; otherwise
            # pyplot keeps it alive across repeated calls.
            plt.close(fig)
    print(f"Graph saved to {output_path}")
if __name__ == "__main__":
    # Regenerate the benchmark figure when this file is run as a script.
    plot_benchmark()