from __future__ import annotations

import __future__
import json
import os
import platform
import subprocess
import sys
import time
import types
from collections import OrderedDict

import numpy as np
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn
from torch.nn.init import trunc_normal_
class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples of a residual branch.

    Stand-in for ``braindecode.modules.DropPath``. In eval mode, or when
    ``drop_prob`` is 0/None, ``forward`` returns its input unchanged. In
    training mode each sample along dim 0 is kept with probability
    ``1 - drop_prob``, and survivors are scaled by ``1 / (1 - drop_prob)``
    so the expected output matches the input.

    Parameters
    ----------
    drop_prob : float or None
        Probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample (dim 0), broadcast over all other dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        if keep_prob > 0.0:
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"
class MLP(nn.Sequential):
    """Minimal stand-in for ``braindecode.modules.MLP``.

    Layout: ``Linear -> activation -> ... -> Linear -> Dropout``. Every linear
    layer except the last is followed by a fresh ``activation()`` instance.

    Parameters
    ----------
    in_features : int
        Input width.
    hidden_features : sequence of int, optional
        Hidden-layer widths. When falsy, two hidden layers of width
        ``in_features`` are used.
    out_features : int, optional
        Output width; falsy values fall back to ``in_features``.
    activation : callable
        Zero-argument factory returning the activation module.
    drop : float
        Dropout probability of the trailing ``nn.Dropout``.
    normalize : bool
        Accepted for signature compatibility; this stub ignores it.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=nn.GELU, drop=0.0, normalize=False):
        if not out_features:
            out_features = in_features
        hidden = hidden_features if hidden_features else (in_features, in_features)
        inputs = (in_features, *hidden)
        outputs = (*hidden, out_features)
        # One (Linear, activation) pair per consecutive pair of widths.
        stack = [
            layer
            for n_in, n_out in zip(inputs, outputs)
            for layer in (nn.Linear(n_in, n_out), activation())
        ]
        # The trailing activation is replaced by dropout, so the last Linear
        # feeds Dropout directly.
        stack[-1] = nn.Dropout(p=drop)
        super().__init__(*stack)
def rescale_parameter(param, layer_id):
    """Divide ``param`` in place by ``sqrt(2 * layer_id)`` and return it.

    Stand-in for ``braindecode.functional.rescale_parameter``.

    Parameters
    ----------
    param : torch.Tensor
        Tensor modified in place.
    layer_id : int or float
        Positive layer index used in the divisor.

    Returns
    -------
    torch.Tensor
        ``param`` itself, so the call can be chained.

    Raises
    ------
    ValueError
        If ``layer_id`` is not positive, because the divisor would be 0
        (giving inf) or the square root of a negative number (giving NaN).
    """
    if layer_id <= 0:
        raise ValueError(f"layer_id must be positive, got {layer_id!r}")
    # A plain Python float divisor avoids allocating a scalar tensor per call.
    return param.div_((2.0 * layer_id) ** 0.5)
class EEGModuleMixin:
    """Minimal stand-in for ``braindecode.models.base.EEGModuleMixin``.

    Records the signal descriptors passed to a model as plain attributes.
    Missing values are derived where possible:

    - ``n_chans`` from ``len(chs_info)``;
    - ``n_times`` from ``int(input_window_seconds * sfreq)``;
    - ``input_window_seconds`` from ``n_times / sfreq``.

    Extra keyword arguments are accepted and ignored. ``super().__init__()``
    is called with no arguments before any attribute is assigned; presumably
    this class is mixed in ahead of ``nn.Module`` (confirm against labram.py).
    """

    def __init__(self, n_outputs=None, n_chans=None, chs_info=None, n_times=None, input_window_seconds=None, sfreq=None, **kwargs):
        super().__init__()
        if n_chans is None and chs_info is not None:
            n_chans = len(chs_info)
        if n_times is None and input_window_seconds is not None and sfreq is not None:
            n_times = int(input_window_seconds * sfreq)
        if input_window_seconds is None and n_times is not None and sfreq:
            input_window_seconds = n_times / sfreq
        self.n_outputs = n_outputs
        self.n_chans = n_chans
        self.chs_info = chs_info
        self.n_times = n_times
        self.sfreq = sfreq
        self._chs_info = chs_info  # private alias kept alongside chs_info
        self.input_window_seconds = input_window_seconds
# Install lightweight stand-ins for the braindecode modules that labram.py
# imports, so its source can be exec'd without braindecode installed.
# NOTE: this unconditionally replaces any installed braindecode in sys.modules.
bpkg = types.ModuleType('braindecode')
bpkg.__path__ = []  # mark as a package so submodule imports are well-formed
bmodels = types.ModuleType('braindecode.models')
bmodels.__path__ = []
bmmb = types.ModuleType('braindecode.models.base')
bmmb.EEGModuleMixin = EEGModuleMixin
bfunc = types.ModuleType('braindecode.functional')
bfunc.rescale_parameter = rescale_parameter
bmods = types.ModuleType('braindecode.modules')
bmods.DropPath = DropPath
bmods.MLP = MLP
# Link children onto their parents so dotted access after
# `import braindecode.models.base` (etc.) resolves, not only `from ... import`.
bpkg.models = bmodels
bpkg.functional = bfunc
bpkg.modules = bmods
bmodels.base = bmmb
sys.modules.update({
    'braindecode': bpkg,
    'braindecode.models': bmodels,
    'braindecode.models.base': bmmb,
    'braindecode.functional': bfunc,
    'braindecode.modules': bmods,
})
# Location of braindecode's labram.py; override with the LABRAM_PATH env var.
LABRAM_PATH = os.environ.get(
    "LABRAM_PATH", "/Users/Shared/braindecode/braindecode/models/labram.py"
)
with open(LABRAM_PATH, encoding="utf-8") as f:
    src = f.read()
# Enable postponed annotation evaluation through compiler flags rather than
# by prepending a `from __future__` line. Prepending shifts every traceback
# line number by one. It also stops a leading docstring from being the module
# docstring, which makes any `from __future__` import inside labram.py a
# SyntaxError.
ns = {}
exec(
    compile(
        src,
        LABRAM_PATH,
        "exec",
        flags=__future__.annotations.compiler_flag,
        dont_inherit=True,
    ),
    ns,
)
Labram = ns['Labram']
# Untimed warm-up passes and timed passes, shared by Python and Rust runs.
WARMUP, REPEATS = 5, 30

# Channel names used by every configuration below.
ch8 = [
    'FP1', 'FP2',
    'F3', 'F4',
    'C3', 'C4',
    'O1', 'O2',
]

# Each entry: (n_chans, n_times, channel names, display label).
CONFIGS = [
    (8, 800, ch8, "8ch×800t"),
    (8, 1600, ch8, "8ch×1600t"),
    (8, 3200, ch8, "8ch×3200t"),
]

# Each entry: (backend name, path of the compiled Rust benchmark example).
RUST_BACKENDS = [
    ("ndarray", "target/release/examples/benchmark_ndarray"),
    ("accelerate", "target/release/examples/benchmark_accelerate"),
    ("metal", "target/release/examples/benchmark_metal"),
]
def bench_python(nc, nt, chn):
    """Time LaBraM forward passes in PyTorch on one random input.

    Builds a ``Labram`` model for ``nc`` channels named ``chn`` and ``nt``
    time samples. It then runs WARMUP untimed passes followed by REPEATS
    timed passes, all under ``torch.no_grad()`` in eval mode.

    Returns
    -------
    list of float
        Wall-clock duration of each timed pass, in milliseconds.
    """
    # Re-seed so the initial weights and the random input are reproducible.
    torch.manual_seed(42)
    model = Labram(
        n_outputs=4,
        n_chans=nc,
        n_times=nt,
        sfreq=200,
        chs_info=[{'ch_name': name} for name in chn],
        patch_size=200,
        embed_dim=200,
        num_layers=2,
        num_heads=10,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_norm=nn.LayerNorm,
        init_values=0.1,
        use_abs_pos_emb=True,
        use_mean_pooling=False,
        neural_tokenizer=True,
        learned_patcher=False,
        drop_prob=0.0,
        attn_drop_prob=0.0,
        drop_path_prob=0.0,
        on_unknown_chs='ignore',
    )
    model.eval()
    x = torch.randn(1, nc, nt)

    with torch.no_grad():
        for _ in range(WARMUP):
            _ = model(x, ch_names=chn)

    elapsed_ms = []
    with torch.no_grad():
        for _ in range(REPEATS):
            # Note: the loop rebinding `_` releases the previous output before
            # `start` is read, so that deallocation is not inside the timed window.
            start = time.perf_counter()
            _ = model(x, ch_names=chn)
            elapsed_ms.append((time.perf_counter() - start) * 1000)
    return elapsed_ms
def bench_rust(binary, nc, nt):
    """Run one compiled Rust benchmark binary and return its timings.

    The binary is invoked as ``binary nc nt WARMUP REPEATS`` with a 120 s
    timeout. It must print a JSON object with a ``"times_ms"`` list on stdout.

    Parameters
    ----------
    binary : str
        Path to the executable.
    nc, nt : int
        Channel count and number of time samples, forwarded as argv.

    Returns
    -------
    list or None
        The ``times_ms`` value. None when the binary is missing, cannot be
        launched, times out, exits non-zero, or prints unparsable output.
        Every failure except a missing binary is reported on stderr. Other
        exceptions, including KeyboardInterrupt, propagate.
    """
    if not os.path.exists(binary):
        return None  # backend not built; skip quietly
    cmd = [binary, str(nc), str(nt), str(WARMUP), str(REPEATS)]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
    except (OSError, subprocess.TimeoutExpired) as exc:
        print(f"  [{binary}] failed to run: {exc}", file=sys.stderr)
        return None
    if r.returncode != 0:
        print(f"  [{binary}] exited with code {r.returncode}: {r.stderr.strip()}", file=sys.stderr)
        return None
    try:
        return json.loads(r.stdout)["times_ms"]
    except (ValueError, KeyError, TypeError) as exc:
        # ValueError covers json.JSONDecodeError; TypeError covers non-object JSON.
        print(f"  [{binary}] unparsable output: {exc!r}", file=sys.stderr)
        return None
def _rust_entry(bk, rs, pm):
    """Build the ``rust_<bk>_*`` result fields for one backend, printing a summary line.

    Parameters
    ----------
    bk : str
        Backend name used in the field keys and the printed summary.
    rs : list of float or None
        Per-iteration timings from ``bench_rust``. A falsy value means the
        backend was unavailable or failed.
    pm : float
        Python mean latency (ms), used for the speedup ratio.

    Returns
    -------
    dict
        ``mean_ms``, ``std_ms`` and ``speedup`` are None when ``rs`` is
        falsy, and ``times_ms`` is then ``[]``. Statistics that are exactly
        0.0 are kept as 0.0 rather than collapsed to None.
    """
    prefix = f"rust_{bk}_"
    if not rs:
        return {prefix + "mean_ms": None, prefix + "std_ms": None, prefix + "speedup": None, prefix + "times_ms": []}
    m, s = float(np.mean(rs)), float(np.std(rs))
    sp = float(pm / m)
    print(f" Rust ({bk:12s}): {m:7.2f} ± {s:.2f} ms ({sp:.2f}x)")
    return {prefix + "mean_ms": m, prefix + "std_ms": s, prefix + "speedup": sp, prefix + "times_ms": rs}


def main():
    """Benchmark LaBraM in PyTorch and every available Rust backend, then plot.

    For each entry in CONFIGS this times the Python model and each binary in
    RUST_BACKENDS, printing mean ± std latency and the Rust speedup over
    Python. Results are written to figures/benchmark_results.json and the
    charts are rendered with generate_charts.
    """
    os.makedirs("figures", exist_ok=True)
    results = {
        "meta": {"platform": platform.platform(), "machine": platform.machine(), "torch_version": torch.__version__},
        "benchmarks": [],
    }
    print(f"Platform: {platform.platform()}\nPyTorch: {torch.__version__}\n")
    for nc, nt, chn, label in CONFIGS:
        print(f"── {label} ──")
        py = bench_python(nc, nt, chn)
        pm, ps = np.mean(py), np.std(py)
        print(f" Python: {pm:7.2f} ± {ps:.2f} ms")
        entry = {
            "label": label,
            "n_chans": nc,
            "n_times": nt,
            "python_mean_ms": float(pm),
            "python_std_ms": float(ps),
            "python_times_ms": py,
        }
        for bk, binary in RUST_BACKENDS:
            entry.update(_rust_entry(bk, bench_rust(binary, nc, nt), pm))
        results["benchmarks"].append(entry)
        print()
    with open("figures/benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    generate_charts(results)
# Bar colours and legend names per series, and the Rust backends the charts look for.
_CHART_COLORS = {"python": "#4C72B0", "ndarray": "#DD8452", "accelerate": "#55A868", "metal": "#C44E52"}
_CHART_NAMES = {"python": "Python (PyTorch)", "ndarray": "Rust (NdArray)", "accelerate": "Rust (Accelerate)", "metal": "Rust (Metal GPU)"}
_CHART_BACKENDS = ("ndarray", "accelerate", "metal")


def _style_axes(ax, x, labels, ylabel, title):
    """Apply the axis labels, title, rotated tick labels, legend and grid shared by both charts."""
    ax.set_xlabel('Configuration')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=30, ha='right')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)


def _save_figure(plt, fig, path):
    """Tighten the layout, write *fig* to *path* at 150 dpi, close it and report the path."""
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"Saved {path}")


def _plot_latency(plt, benches, x, labels, active):
    """Draw grouped bars of mean latency ± std: Python first, then each active Rust backend."""
    n_bars = 1 + len(active)
    width = 0.8 / n_bars
    first = x - width * (n_bars - 1) / 2  # centre of the leftmost bar in each group
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(first, [b["python_mean_ms"] for b in benches], width,
           yerr=[b["python_std_ms"] for b in benches], label=_CHART_NAMES["python"],
           color=_CHART_COLORS["python"], capsize=2, alpha=0.85)
    for i, bk in enumerate(active, start=1):
        # A backend that failed for some configuration is drawn as 0 there.
        means = [b.get(f"rust_{bk}_mean_ms") or 0 for b in benches]
        stds = [b.get(f"rust_{bk}_std_ms") or 0 for b in benches]
        ax.bar(first + width * i, means, width, yerr=stds, label=_CHART_NAMES[bk],
               color=_CHART_COLORS[bk], capsize=2, alpha=0.85)
    _style_axes(ax, x, labels, 'Latency (ms)', 'LaBraM Inference Latency')
    _save_figure(plt, fig, 'figures/inference_latency.png')


def _plot_speedup(plt, benches, x, labels, active):
    """Draw grouped bars of Rust speedup over Python, with value labels and a parity line at 1.0."""
    width = 0.8 / len(active)
    first = x - width * (len(active) - 1) / 2
    fig, ax = plt.subplots(figsize=(10, 6))
    for i, bk in enumerate(active):
        speedups = [b.get(f"rust_{bk}_speedup") or 0 for b in benches]
        bars = ax.bar(first + width * i, speedups, width, color=_CHART_COLORS[bk], alpha=0.85, label=_CHART_NAMES[bk])
        for bar, sp in zip(bars, speedups):
            if sp > 0:  # leave missing results (drawn as 0) unlabelled
                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, f'{sp:.2f}x',
                        ha='center', va='bottom', fontsize=9, fontweight='bold')
    ax.axhline(y=1.0, color='gray', linestyle='--', linewidth=1, label='Parity')
    _style_axes(ax, x, labels, 'Speedup', 'Rust Speedup over Python')
    _save_figure(plt, fig, 'figures/speedup.png')


def generate_charts(results):
    """Render latency and speedup bar charts from a benchmark results dict.

    Parameters
    ----------
    results : dict
        Must have a ``"benchmarks"`` list whose entries carry ``label``,
        ``python_mean_ms``, ``python_std_ms`` and, optionally,
        ``rust_<backend>_mean_ms`` / ``_std_ms`` / ``_speedup``. This is the
        shape produced by main() and saved to benchmark_results.json.

    Writes figures/inference_latency.png. figures/speedup.png is written only
    when at least one Rust backend has results. The figures directory is
    created if needed, so this also works on results reloaded from JSON.
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: only files are written
    import matplotlib.pyplot as plt

    os.makedirs("figures", exist_ok=True)
    benches = results["benchmarks"]
    labels = [b["label"] for b in benches]
    x = np.arange(len(labels))
    active = [bk for bk in _CHART_BACKENDS if any(b.get(f"rust_{bk}_mean_ms") is not None for b in benches)]
    _plot_latency(plt, benches, x, labels, active)
    if active:
        _plot_speedup(plt, benches, x, labels, active)
    else:
        print("No Rust results; skipped figures/speedup.png")
if __name__ == "__main__":
    # Run the full benchmark only when executed as a script, not on import.
    main()