import sys, types, time, json, os, subprocess, platform
import torch, math
import numpy as np
from functools import partial
from einops import rearrange, repeat
from torch import nn
class DropPath(nn.Module):
    """Stub of braindecode's DropPath: acts as the identity (nothing is dropped)."""

    def __init__(self, drop_prob=0.0):
        # drop_prob is accepted only for signature compatibility; it is unused.
        super().__init__()

    def forward(self, x):
        # Identity pass-through — inference-time stand-in for stochastic depth.
        return x
class Conv1dWithConstraint(nn.Conv1d):
    """Stub of braindecode's Conv1dWithConstraint: a plain Conv1d.

    The max-norm weight constraint of the real class is intentionally
    not applied; ``max_norm`` is swallowed for signature compatibility.
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
class LinearWithConstraint(nn.Linear):
    """Stub of braindecode's LinearWithConstraint: a plain Linear layer.

    ``max_norm`` is accepted and ignored; no weight constraint is enforced.
    """

    def __init__(self, *args, max_norm=1.0, **kwargs):
        super().__init__(*args, **kwargs)
class EEGModuleMixin:
    """Minimal stand-in for ``braindecode.models.base.EEGModuleMixin``.

    Stores the standard EEG-model metadata keywords as attributes so that
    the exec-loaded model can read them. Extra keyword arguments are
    accepted and ignored.

    Parameters
    ----------
    n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq :
        Metadata values stored verbatim as attributes (all default None).
    """

    def __init__(self, n_outputs=None, n_chans=None, chs_info=None, n_times=None,
                 input_window_seconds=None, sfreq=None, **kwargs):
        super().__init__()
        self.n_outputs = n_outputs
        self.n_chans = n_chans
        self.chs_info = chs_info
        self.n_times = n_times
        # Fix: this kwarg was previously accepted but silently dropped;
        # store it like the other metadata so a model reading
        # self.input_window_seconds does not hit AttributeError.
        self.input_window_seconds = input_window_seconds
        self.sfreq = sfreq
        # Private copy kept alongside the public attribute, as in the original stub.
        self._chs_info = chs_info
# Register stub braindecode modules in sys.modules so that the imports in
# eegpt.py resolve to the lightweight stand-ins defined above.
_STUB_MODULES = {
    'braindecode': {},
    'braindecode.models': {},
    'braindecode.models.base': {'EEGModuleMixin': EEGModuleMixin},
    'braindecode.modules': {'DropPath': DropPath},
    'braindecode.modules.convolution': {'Conv1dWithConstraint': Conv1dWithConstraint},
    'braindecode.modules.linear': {'LinearWithConstraint': LinearWithConstraint},
}
for _mod_name, _attrs in _STUB_MODULES.items():
    _mod = types.ModuleType(_mod_name)
    for _attr_name, _obj in _attrs.items():
        setattr(_mod, _attr_name, _obj)
    sys.modules[_mod_name] = _mod
import importlib.util

# Load the EEGPT model source straight from the braindecode checkout and
# execute it against the stub modules registered above.
_EEGPT_PATH = '/Users/Shared/braindecode/braindecode/models/eegpt.py'
with open(_EEGPT_PATH) as _fh:
    _eegpt_source = _fh.read()
# Prepend lazy annotations so newer type-hint syntax in eegpt.py parses
# on older interpreters as well.
_eegpt_source = 'from __future__ import annotations\n' + _eegpt_source
eegpt_ns = {}
exec(compile(_eegpt_source, 'eegpt.py', 'exec'), eegpt_ns)
EEGPT = eegpt_ns['EEGPT']
# Benchmark protocol: discard WARMUP untimed iterations, then record REPEATS timed runs.
WARMUP, REPEATS = 5, 30
# Benchmark configurations: (n_chans, n_times, display label).
CONFIGS = [
    (4, 1000, "4ch×1000t"),
    (8, 1000, "8ch×1000t"),
    (16, 1000, "16ch×1000t"),
    (22, 1000, "22ch×1000t"),
    (32, 1000, "32ch×1000t"),
    (64, 1000, "64ch×1000t"),
    (22, 2000, "22ch×2000t"),
    (22, 4000, "22ch×4000t"),
]
# Rust comparison binaries: (backend name, path to the pre-built example executable).
# Missing binaries are skipped gracefully by bench_rust.
RUST_BACKENDS = [
    ("ndarray", "target/release/examples/benchmark_ndarray"),
    ("accelerate", "target/release/examples/benchmark_accelerate"),
    ("metal", "target/release/examples/benchmark_metal"),
]
def bench_python(nc, nt):
    """Time EEGPT inference in PyTorch for a single (1, nc, nt) input.

    Builds a small (depth=2) EEGPT model, runs WARMUP untimed forward
    passes, then returns a list of REPEATS per-call latencies in ms.
    """
    torch.manual_seed(42)  # deterministic weights and input across runs
    model = EEGPT(n_outputs=4, n_chans=nc, n_times=nt, sfreq=200,
                  patch_size=64, patch_stride=32, embed_num=4, embed_dim=512,
                  depth=2, num_heads=8, mlp_ratio=4.0, drop_prob=0.0, chan_proj_type='none')
    model.eval()
    batch = torch.randn(1, nc, nt)
    with torch.no_grad():
        for _ in range(WARMUP):
            model(batch)
    latencies_ms = []
    with torch.no_grad():
        for _ in range(REPEATS):
            start = time.perf_counter()
            model(batch)
            latencies_ms.append((time.perf_counter() - start) * 1000)
    return latencies_ms
def bench_rust(binary, nc, nt):
    """Run one Rust benchmark binary and return its per-iteration times.

    The binary is invoked as ``binary nc nt WARMUP REPEATS`` and is
    expected to print JSON with a ``times_ms`` list on stdout.

    Returns
    -------
    list or None
        The ``times_ms`` list, or None when the binary is missing, exits
        non-zero, times out, or produces unparseable/incomplete output.
    """
    if not os.path.exists(binary):
        return None
    try:
        result = subprocess.run([binary, str(nc), str(nt), str(WARMUP), str(REPEATS)],
                                capture_output=True, text=True, timeout=120)
        if result.returncode != 0:
            return None
        return json.loads(result.stdout)["times_ms"]
    # Narrowed from a bare `except:` (which also swallowed KeyboardInterrupt
    # and SystemExit): catch only the failures this call can actually raise.
    except (OSError, subprocess.SubprocessError, json.JSONDecodeError, KeyError, TypeError):
        return None
def main():
    """Run every benchmark configuration, save JSON results, render charts.

    For each (n_chans, n_times) configuration the Python/PyTorch latency
    is measured, then each Rust backend (when its binary exists). Results
    are written to figures/benchmark_results.json and passed to
    generate_charts.
    """
    os.makedirs("figures", exist_ok=True)
    results = {"meta": {"platform": platform.platform(), "machine": platform.machine(),
                        "torch_version": torch.__version__, "warmup": WARMUP, "repeats": REPEATS},
               "benchmarks": []}
    print(f"Platform: {platform.platform()}\nPyTorch: {torch.__version__}, depth=2, warmup={WARMUP}, repeats={REPEATS}\n")
    for nc, nt, label in CONFIGS:
        print(f"── {label} ──")
        py = bench_python(nc, nt)
        pm, ps = np.mean(py), np.std(py)
        print(f" Python (PyTorch): {pm:7.2f} ± {ps:.2f} ms")
        entry = {"label": label, "n_chans": nc, "n_times": nt,
                 "python_times_ms": py, "python_mean_ms": float(pm), "python_std_ms": float(ps)}
        for bk, binary in RUST_BACKENDS:
            rs = bench_rust(binary, nc, nt)
            if rs:
                m, s = np.mean(rs), np.std(rs)
                sp = pm / m
                print(f" Rust ({bk:12s}): {m:7.2f} ± {s:.2f} ms ({sp:.2f}x)")
            else:
                m = s = sp = None
                rs = []
            entry[f"rust_{bk}_times_ms"] = rs
            # Fix: compare against None, not truthiness — a legitimate 0.0
            # (e.g. a zero std) must be recorded, not collapsed to null.
            entry[f"rust_{bk}_mean_ms"] = float(m) if m is not None else None
            entry[f"rust_{bk}_std_ms"] = float(s) if s is not None else None
            entry[f"rust_{bk}_speedup"] = float(sp) if sp is not None else None
        results["benchmarks"].append(entry)
        print()
    with open("figures/benchmark_results.json", "w") as f:
        json.dump(results, f, indent=2)
    generate_charts(results)
def generate_charts(results):
    """Render benchmark charts from the results dict produced by main().

    Saves PNGs under figures/: inference_latency (grouped bars), speedup
    (bars vs the Python baseline), plus channel_scaling and time_scaling
    line charts when more than one matching configuration exists.
    Matplotlib is imported lazily with the Agg backend so the script can
    run headless.
    """
    import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt
    B = results["benchmarks"]; labels = [b["label"] for b in B]
    colors = {"python":"#4C72B0","ndarray":"#DD8452","accelerate":"#55A868","metal":"#C44E52"}
    names = {"python":"Python (PyTorch)","ndarray":"Rust (NdArray)","accelerate":"Rust (Accelerate)","metal":"Rust (Metal GPU)"}
    # Backends that produced at least one non-null, non-zero mean.
    active = [bk for bk in ["ndarray","accelerate","metal"] if any(b.get(f"rust_{bk}_mean_ms") for b in B)]
    # Chart 1: grouped latency bars — one group per config, one bar per implementation.
    n_bars = 1+len(active); width = 0.8/n_bars; x = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.bar(x-width*(n_bars-1)/2, [b["python_mean_ms"] for b in B], width, yerr=[b["python_std_ms"] for b in B],
           label=names["python"], color=colors["python"], capsize=2, alpha=0.85)
    for i,bk in enumerate(active):
        # Missing measurements become zero-height bars so group positions stay aligned.
        ms=[b.get(f"rust_{bk}_mean_ms") or 0 for b in B]; ss=[b.get(f"rust_{bk}_std_ms") or 0 for b in B]
        ax.bar(x-width*(n_bars-1)/2+width*(i+1), ms, width, yerr=ss, label=names[bk], color=colors[bk], capsize=2, alpha=0.85)
    ax.set_xlabel('Configuration'); ax.set_ylabel('Latency (ms)')
    ax.set_title('EEGPT Inference Latency', fontsize=14, fontweight='bold')
    ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha='right'); ax.legend(fontsize=10); ax.grid(axis='y', alpha=0.3)
    plt.tight_layout(); plt.savefig('figures/inference_latency.png', dpi=150); plt.close(); print("Saved figures/inference_latency.png")
    # Chart 2: speedup factors of each Rust backend relative to Python.
    fig, ax = plt.subplots(figsize=(14, 6)); sp_w = 0.8/max(len(active),1)
    for i,bk in enumerate(active):
        sps=[b.get(f"rust_{bk}_speedup") or 0 for b in B]
        bars=ax.bar(x-sp_w*(len(active)-1)/2+sp_w*i, sps, sp_w, color=colors[bk], alpha=0.85, label=names[bk])
        for bar,sp in zip(bars,sps):
            # Annotate only real measurements (zero means "no data", drawn as a flat bar).
            if sp>0: ax.text(bar.get_x()+bar.get_width()/2, bar.get_height()+0.05, f'{sp:.1f}x', ha='center', va='bottom', fontsize=7, fontweight='bold')
    # 1.0x reference line: same speed as the Python baseline.
    ax.axhline(y=1.0, color='gray', linestyle='--', linewidth=1, label='Parity')
    ax.set_xlabel('Configuration'); ax.set_ylabel('Speedup (vs Python)')
    ax.set_title('Rust Speedup over Python (PyTorch)', fontsize=14, fontweight='bold')
    ax.set_xticks(x); ax.set_xticklabels(labels, rotation=30, ha='right'); ax.legend(fontsize=9); ax.grid(axis='y', alpha=0.3)
    plt.tight_layout(); plt.savefig('figures/speedup.png', dpi=150); plt.close(); print("Saved figures/speedup.png")
    # Chart 3: latency vs channel count, using only the T=1000 configurations.
    cb = [b for b in B if b["n_times"]==1000]
    if len(cb)>1:
        fig, ax = plt.subplots(figsize=(9,5)); ch=[b["n_chans"] for b in cb]
        ax.plot(ch, [b["python_mean_ms"] for b in cb], 'o-', color=colors["python"], label=names["python"], linewidth=2, markersize=7)
        for bk in active:
            la=[b.get(f"rust_{bk}_mean_ms") for b in cb]
            # Plot only the points where this backend actually ran.
            if any(v for v in la): ax.plot([c for c,v in zip(ch,la) if v], [v for v in la if v], 's-', color=colors[bk], label=names[bk], linewidth=2, markersize=7)
        ax.set_xlabel('Number of Channels'); ax.set_ylabel('Latency (ms)')
        ax.set_title('Latency vs Channel Count (T=1000)', fontsize=14, fontweight='bold')
        ax.legend(); ax.grid(alpha=0.3); plt.tight_layout(); plt.savefig('figures/channel_scaling.png', dpi=150); plt.close()
        print("Saved figures/channel_scaling.png")
    # Chart 4: latency vs signal length, using only the C=22 configurations.
    tb = [b for b in B if b["n_chans"]==22]
    if len(tb)>1:
        fig, ax = plt.subplots(figsize=(9,5)); ts=[b["n_times"] for b in tb]
        ax.plot(ts, [b["python_mean_ms"] for b in tb], 'o-', color=colors["python"], label=names["python"], linewidth=2, markersize=7)
        for bk in active:
            la=[b.get(f"rust_{bk}_mean_ms") for b in tb]
            # Plot only the points where this backend actually ran.
            if any(v for v in la): ax.plot([t for t,v in zip(ts,la) if v], [v for v in la if v], 's-', color=colors[bk], label=names[bk], linewidth=2, markersize=7)
        ax.set_xlabel('Number of Time Samples'); ax.set_ylabel('Latency (ms)')
        ax.set_title('Latency vs Signal Length (C=22)', fontsize=14, fontweight='bold')
        ax.legend(); ax.grid(alpha=0.3); plt.tight_layout(); plt.savefig('figures/time_scaling.png', dpi=150); plt.close()
        print("Saved figures/time_scaling.png")
if __name__ == "__main__": main()