from __future__ import annotations
import os
import statistics
import time
from typing import Iterable
import torch
from transformers import AutoModelForCausalLM
# Fixed pool of token ids that make_ids() cycles through to build
# deterministic model inputs.
# NOTE(review): ids reach 110000; presumably every id is inside the target
# model's vocabulary -- confirm when pointing this script at another model.
TOKEN_POOL = [
    # Assorted hand-picked ids.
    1, 17, 42, 314, 2718, 9001, 27182, 8128, 65535, 12345, 256, 1024,
    4096, 16384, 32768,
    # 100, 200, ..., 900
    *range(100, 1000, 100),
    # 1000, 2000, ..., 110000
    *range(1000, 110001, 1000),
]
def make_ids(batch: int, seq: int, device: str) -> torch.Tensor:
    """Build deterministic token-id inputs from ``TOKEN_POOL``.

    One row is produced per batch index. Row ``b`` starts at pool offset
    ``(b * 7) % len(TOKEN_POOL)`` and takes ``seq`` consecutive ids, wrapping
    around the end of the pool.

    Args:
        batch: Number of rows to generate.
        seq: Number of token ids per row.
        device: Device string handed to ``torch.tensor``.

    Returns:
        A ``torch.long`` tensor holding the generated rows on ``device``.
    """
    pool_size = len(TOKEN_POOL)
    grid = [
        [TOKEN_POOL[((row * 7) % pool_size + col) % pool_size] for col in range(seq)]
        for row in range(batch)
    ]
    return torch.tensor(grid, dtype=torch.long, device=device)
def sync(device: str) -> None:
    """Synchronize the given backend so wall-clock timings are meaningful.

    Only the exact strings ``"cuda"`` and ``"mps"`` trigger a synchronize
    call; any other value (for example ``"cpu"``) is a no-op.

    Args:
        device: Backend name as used elsewhere in this script.
    """
    if device == "cuda":
        torch.cuda.synchronize()
        return
    if device == "mps":
        torch.mps.synchronize()
def top1_match(a: torch.Tensor, b: torch.Tensor) -> tuple[int, int]:
    """Count positions where ``a`` and ``b`` agree on the arg-max entry.

    Both tensors are flattened to two dimensions, keeping the last axis, and
    the arg-max along that axis is compared row by row.

    Args:
        a: First tensor.
        b: Second tensor; must yield the same number of rows as ``a`` after
            flattening for the comparison to be valid.

    Returns:
        ``(matches, total)`` where ``total`` is the number of rows of ``a``
        after flattening.
    """
    picks_a = torch.argmax(a.reshape(-1, a.shape[-1]), dim=-1)
    picks_b = torch.argmax(b.reshape(-1, b.shape[-1]), dim=-1)
    agreeing = picks_a.eq(picks_b).sum().item()
    return int(agreeing), int(picks_a.numel())
def metrics(a: torch.Tensor, b: torch.Tensor) -> tuple[float, float, float, float, int, int]:
    """Summarise how closely two tensors agree.

    Each tensor is detached and copied to the CPU once; that single host copy
    feeds both the float comparison and the top-1 comparison, so large
    tensors are not transferred from the device twice.

    Args:
        a: First tensor (``main`` passes the target-backend logits).
        b: Second tensor (``main`` passes the CPU reference logits). Expected
            to have the same shape as ``a``.

    Returns:
        A 6-tuple ``(max_abs, mean_abs, cos_mean, cos_min, top1_match,
        top1_total)``:

        * ``max_abs`` / ``mean_abs``: max and mean of ``|a - b|``.
        * ``cos_mean`` / ``cos_min``: mean and min cosine similarity, computed
          per leading-axis entry after flattening all remaining axes.
        * ``top1_match`` / ``top1_total``: result of ``top1_match(a, b)``.
    """
    # One device-to-host transfer per tensor; reused for every metric below.
    a_host = a.detach().cpu()
    b_host = b.detach().cpu()

    # Flatten everything after the leading axis so each row is one sample.
    af = a_host.float().reshape(a_host.shape[0], -1)
    bf = b_host.float().reshape(b_host.shape[0], -1)

    diff = (af - bf).abs()
    cos = torch.nn.functional.cosine_similarity(af, bf, dim=-1)
    t1, total = top1_match(a_host, b_host)
    return (
        float(diff.max().item()),
        float(diff.mean().item()),
        float(cos.mean().item()),
        float(cos.min().item()),
        t1,
        total,
    )
@torch.no_grad()
def run_case(model, device: str, batch: int, seq: int, keep: bool, reps: int):
    """Run one forward-pass configuration and time it.

    The model is called with ``input_ids`` from ``make_ids`` and
    ``use_cache=False``; when ``keep`` is true, ``logits_to_keep=1`` is passed
    as well. Two untimed passes run first, followed by ``reps`` timed passes,
    each followed by ``sync(device)`` so the measurement covers finished work.

    Args:
        model: Callable accepting the keyword arguments above and returning an
            object with a ``logits`` attribute.
        device: Device string used for the inputs and for ``sync``.
        batch: Batch size passed to ``make_ids``.
        seq: Sequence length passed to ``make_ids``.
        keep: Whether to pass ``logits_to_keep=1``.
        reps: Number of timed passes.

    Returns:
        ``(logits, times)`` where ``logits`` comes from the last pass executed
        and ``times`` lists the per-pass wall-clock durations in milliseconds.
    """
    call_args = {"input_ids": make_ids(batch, seq, device), "use_cache": False}
    if keep:
        call_args["logits_to_keep"] = 1

    # Untimed passes ahead of the measured ones.
    for _ in range(2):
        out = model(**call_args).logits
        sync(device)

    elapsed_ms = []
    for _ in range(reps):
        started = time.perf_counter()
        out = model(**call_args).logits
        sync(device)
        elapsed_ms.append((time.perf_counter() - started) * 1000.0)
    return out, elapsed_ms
def emit_row(
    backend: str,
    mode: str,
    batch: int,
    seq: int,
    shape_ok: bool,
    metric_values: tuple[float, float, float, float, int, int],
    times: Iterable[float],
    status: str,
    message: str,
) -> None:
    """Print one comma-separated result row to stdout.

    The row always starts with the literal fields ``prefill,python`` and then
    lists the arguments in the column order of the header printed by ``main``.

    Args:
        backend: Backend label for the row.
        mode: Mode label for the row.
        batch: Batch size of the case.
        seq: Sequence length of the case.
        shape_ok: Shape-check result; written as its lower-cased ``str()``.
        metric_values: ``(max_abs, mean_abs, cos_mean, cos_min, top1_match,
            top1_total)`` as returned by ``metrics``.
        times: Per-pass durations in milliseconds; must be non-empty because
            the minimum and median are reported.
        status: Status field, written verbatim.
        message: Message field, written verbatim.
    """
    # Materialise once: both min() and median() need to read the samples.
    samples = list(times)
    row_template = (
        "prefill,python,{backend},{mode},{batch},{seq},{shape_ok},"
        "{max_abs:.6f},{mean_abs:.6f},{cos_mean:.7f},{cos_min:.7f},"
        "{top1_match},{top1_total},{min_ms:.1f},{median_ms:.1f},{status},{message}"
    )
    fields = {
        "backend": backend,
        "mode": mode,
        "batch": batch,
        "seq": seq,
        "shape_ok": str(shape_ok).lower(),
        "max_abs": metric_values[0],
        "mean_abs": metric_values[1],
        "cos_mean": metric_values[2],
        "cos_min": metric_values[3],
        "top1_match": metric_values[4],
        "top1_total": metric_values[5],
        "min_ms": min(samples),
        "median_ms": statistics.median(samples),
        "status": status,
        "message": message,
    }
    print(row_template.format(**fields))
def main() -> None:
    """Compare target-backend logits against a CPU reference and print CSV.

    Configuration comes from the environment:

    * ``RLX_QWEN3_DIR``: model directory (defaults to a local Qwen3-0.6B path).
    * ``RLX_QWEN3_MATRIX_REPS``: timed repetitions per case, clamped to at
      least 1 (default 3).

    The target backend is ``"mps"`` when available, otherwise ``"cpu"``, in
    which case the single CPU model is compared against itself. For every
    combination of batch size, sequence length and mode (``full`` / ``last``)
    one CSV row is printed after the header line.
    """
    model_dir = os.environ.get("RLX_QWEN3_DIR", "/Users/Shared/rlx/weights/Qwen3-0.6B")
    reps = max(1, int(os.environ.get("RLX_QWEN3_MATRIX_REPS", "3")))
    target = "mps" if torch.backends.mps.is_available() else "cpu"
    print(
        "kind,impl,backend,mode,batch,seq,shape_ok,max_abs,mean_abs,cos_mean,cos_min,"
        "top1_match,top1_total,min_ms,median_ms,status,message"
    )
    cpu_model = AutoModelForCausalLM.from_pretrained(
        model_dir, local_files_only=True, dtype=torch.float32
    ).to("cpu").eval()
    # Reuse the CPU model when there is no separate target backend.
    target_model = cpu_model if target == "cpu" else AutoModelForCausalLM.from_pretrained(
        model_dir, local_files_only=True, dtype=torch.float32
    ).to(target).eval()
    for batch in (1, 2, 4):
        for seq in (8, 32, 64, 128):
            for keep, mode in ((False, "full"), (True, "last")):
                ref, _ = run_case(cpu_model, "cpu", batch, seq, keep, 1)
                got, times = run_case(target_model, target, batch, seq, keep, reps)
                shape_ok = tuple(ref.shape) == tuple(got.shape)
                if shape_ok:
                    vals = metrics(got, ref)
                    status = "ok" if vals[4] == vals[5] else "fail"
                else:
                    # Element-wise metrics are undefined for mismatched shapes
                    # (they would raise or silently broadcast), so report the
                    # failure with NaN placeholders instead of aborting the run.
                    nan = float("nan")
                    vals = (nan, nan, nan, nan, 0, 0)
                    status = "fail"
                emit_row(target.upper(), mode, batch, seq, shape_ok, vals, times, status, "cpu_ref")


if __name__ == "__main__":
    main()