import json
import os
import sys
import time
import torch
import torch.nn as nn
from x_transformers import Encoder
# Model hyperparameters for the fMRI encoder benchmark.
HIDDEN = 1152             # transformer hidden width
DEPTH = 8                 # number of transformer layers
HEADS = 8                 # attention heads
FF_MULT = 4               # feed-forward expansion factor
MAX_SEQ_LEN = 1024        # max sequence length covered by the positional embedding
LOW_RANK = 2048           # bottleneck width before the per-subject readout
# NOTE: the original source fused the next two assignments onto one line
# (a syntax error); split here.
N_OUTPUTS = 20484         # output channels (presumably cortical vertices -- TODO confirm)
N_OUTPUT_TIMESTEPS = 100  # temporal resolution of the pooled prediction
N_SUBJECTS = 25           # number of subject-specific readout heads
SUBJECT_DROPOUT = 0.1     # non-zero => an extra shared (subject-agnostic) head exists

# Input modalities: name -> (number of stacked feature layers, per-layer feature dim).
MODALITIES = {
    "text": (3, 3072),
    "audio": (3, 1024),
    "video": (3, 1408),
}
class SubjectLayersModel(nn.Module):
    """Per-subject linear readout over channel-first sequences.

    Holds one (in_ch, out_ch) linear map per subject, plus one extra shared
    map in the last slot when ``subject_dropout`` is non-zero. Input is
    (B, in_ch, T); output is (B, out_ch, T).
    """

    def __init__(self, in_ch, out_ch, n_subjects=25, subject_dropout=0.1):
        super().__init__()
        self.n_subjects = n_subjects
        self.subject_dropout = subject_dropout
        # Extra slot at index ``n_subjects`` holds the shared weights,
        # present only when subject dropout is enabled.
        n = n_subjects + (1 if subject_dropout else 0)
        self.weights = nn.Parameter(torch.randn(n, in_ch, out_ch) / in_ch**0.5)
        self.bias = nn.Parameter(torch.randn(n, out_ch) / in_ch**0.5)

    def forward(self, x, subjects=None):
        """Apply the readout.

        Args:
            x: (B, in_ch, T) input features.
            subjects: optional (B,) long tensor of subject indices. When
                None, the shared slot is used (original code indexed
                ``self.n_subjects`` unconditionally, which is out of bounds
                when ``subject_dropout == 0``; fall back to slot 0 then).
                The original also ignored ``subjects`` entirely -- fixed.

        Returns:
            (B, out_ch, T) tensor.
        """
        if subjects is None:
            idx = self.n_subjects if self.subject_dropout else 0
            w = self.weights[idx]
            out = torch.einsum("bct,cd->bdt", x, w)
            out = out + self.bias[idx].view(1, -1, 1)
        else:
            # One weight matrix per batch element.
            w = self.weights[subjects]               # (B, in_ch, out_ch)
            out = torch.einsum("bct,bcd->bdt", x, w)
            out = out + self.bias[subjects].unsqueeze(-1)
        return out
class FmriEncoderModel(nn.Module):
    """Multimodal transformer encoder predicting fMRI responses.

    Each modality's stacked layer features are projected to an equal slice
    of ``HIDDEN`` channels, concatenated along the channel axis, augmented
    with a learned temporal positional embedding, passed through a
    transformer encoder, mapped through a low-rank bottleneck and a
    per-subject readout, and finally pooled to ``N_OUTPUT_TIMESTEPS`` steps.
    """

    def __init__(self):
        super().__init__()
        n_mod = len(MODALITIES)
        self.projectors = nn.ModuleDict()
        for name, (n_layers, feat_dim) in MODALITIES.items():
            # Flattened (layers * features) -> this modality's slice of HIDDEN.
            self.projectors[name] = nn.Linear(n_layers * feat_dim, HIDDEN // n_mod)
        self.time_pos_embed = nn.Parameter(torch.randn(1, MAX_SEQ_LEN, HIDDEN))
        self.encoder = Encoder(
            dim=HIDDEN,
            depth=DEPTH,
            heads=HEADS,
            ff_mult=FF_MULT,
            use_scalenorm=True,
            rotary_pos_emb=True,
            scale_residual=True,
            attn_dim_head=HIDDEN // HEADS,
        )
        self.low_rank_head = nn.Linear(HIDDEN, LOW_RANK, bias=False)
        self.predictor = SubjectLayersModel(LOW_RANK, N_OUTPUTS, N_SUBJECTS, SUBJECT_DROPOUT)
        self.pooler = nn.AdaptiveAvgPool1d(N_OUTPUT_TIMESTEPS)

    def forward(self, features):
        """Encode a dict of modality tensors.

        Args:
            features: mapping from modality name to a (B, n_layers*feat_dim, T)
                tensor. Missing modalities are zero-filled.

        Returns:
            (B, N_OUTPUTS, N_OUTPUT_TIMESTEPS) tensor.
        """
        # Infer batch size and sequence length from any provided modality.
        ref = next(iter(features.values()))
        B, T = ref.shape[0], ref.shape[-1]
        n_mod = len(MODALITIES)
        tensors = []
        for name in MODALITIES:
            if name in features:
                # (B, C, T) -> (B, T, C) -> project to the modality slice.
                data = self.projectors[name](features[name].transpose(1, 2))
            else:
                # BUG FIX: original used ``data.device`` here, which is
                # unbound when the first modality is absent; anchor the
                # zero filler to a parameter's device instead.
                data = torch.zeros(
                    B, T, HIDDEN // n_mod, device=self.time_pos_embed.device
                )
            tensors.append(data)
        x = torch.cat(tensors, dim=-1)              # (B, T, HIDDEN)
        x = x + self.time_pos_embed[:, : x.size(1)]
        x = self.encoder(x)
        # Bottleneck, then switch to channel-first for the readout/pooler.
        # (Original transposed to (B, H, T) and immediately transposed back
        # before the head -- redundant; removed.)
        x = self.low_rank_head(x).transpose(1, 2)   # (B, LOW_RANK, T)
        x = self.predictor(x)                       # (B, N_OUTPUTS, T)
        return self.pooler(x)                       # (B, N_OUTPUTS, N_OUTPUT_TIMESTEPS)
def make_input(device, batch_size=1, T=100, modalities=None):
    """Build random per-modality feature tensors for benchmarking.

    Args:
        device: torch device (or device string) to allocate tensors on.
        batch_size: leading batch dimension.
        T: number of timesteps (trailing dimension).
        modalities: optional mapping name -> (n_layers, feat_dim);
            defaults to the module-level ``MODALITIES``. (Parameterized so
            the function is usable/testable without the global.)

    Returns:
        dict mapping modality name to a (batch_size, n_layers*feat_dim, T)
        random tensor.
    """
    if modalities is None:
        modalities = MODALITIES
    return {
        name: torch.randn(batch_size, n_layers * feat_dim, T, device=device)
        for name, (n_layers, feat_dim) in modalities.items()
    }
def benchmark(model, features, n_warmup=5, n_runs=20, device_name="cpu"):
    """Time ``model(features)`` and return summary statistics.

    Args:
        model: module to benchmark (put into eval mode here).
        features: input passed verbatim to the model on every run.
        n_warmup: untimed warmup iterations.
        n_runs: timed iterations (must be >= 1).
        device_name: device string; when it contains "mps" the Apple GPU
            queue is synchronized around each timed run so wall-clock times
            are meaningful.

    Returns:
        dict with mean/min/max/std latency in ms, ``n_runs``, and the
        model's output shape.
    """
    model.eval()
    is_mps = "mps" in device_name  # hoisted: checked on every iteration otherwise
    with torch.no_grad():
        for _ in range(n_warmup):
            model(features)
        if is_mps:
            torch.mps.synchronize()
        times = []
        out = None
        for _ in range(n_runs):
            if is_mps:
                torch.mps.synchronize()
            t0 = time.perf_counter()
            out = model(features)
            if is_mps:
                torch.mps.synchronize()
            times.append((time.perf_counter() - t0) * 1000)
    # Compute the mean once (the original re-derived it inside the std
    # generator for every sample).
    mean_ms = sum(times) / len(times)
    return {
        "mean_ms": mean_ms,
        "min_ms": min(times),
        "max_ms": max(times),
        "std_ms": (sum((t - mean_ms) ** 2 for t in times) / len(times)) ** 0.5,
        "n_runs": n_runs,
        "output_shape": list(out.shape),
    }
def _report(r):
    """Print the one-line timing summary for a benchmark result dict."""
    print(f" Mean: {r['mean_ms']:.1f} ms, Min: {r['min_ms']:.1f} ms, Std: {r['std_ms']:.1f} ms")


def main():
    """Benchmark the encoder on CPU (1 thread, all threads) and, when
    available, Apple MPS; write the results next to this file as JSON."""
    results = {}
    n_warmup = 5
    n_runs = 20

    print("=== Python CPU (1 thread) ===")
    torch.set_num_threads(1)
    model = FmriEncoderModel().float()
    features = make_input("cpu")
    r = benchmark(model, features, n_warmup, n_runs, "cpu")
    _report(r)
    results["python_cpu_1thread"] = r

    print("=== Python CPU (all threads) ===")
    # (Original called torch.set_num_threads(torch.get_num_threads()) here,
    # a no-op immediately overridden below -- removed.)
    n_cores = os.cpu_count() or 10
    torch.set_num_threads(n_cores)
    print(f" Using {torch.get_num_threads()} threads")
    model = FmriEncoderModel().float()
    features = make_input("cpu")
    r = benchmark(model, features, n_warmup, n_runs, "cpu")
    _report(r)
    results["python_cpu_multithread"] = r

    if torch.backends.mps.is_available():
        print("=== Python MPS (Apple GPU) ===")
        device = torch.device("mps")
        model = FmriEncoderModel().float().to(device)
        features = make_input(device)
        r = benchmark(model, features, n_warmup, n_runs, "mps")
        _report(r)
        results["python_mps"] = r

    out_path = os.path.join(os.path.dirname(__file__), "results_python.json")
    with open(out_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to {out_path}")


if __name__ == "__main__":
    main()