from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Iterable, List
import numpy as np
import torch
from safetensors.numpy import save_file
from transformers import AutoModelForCausalLM, AutoTokenizer
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for this trace script."""
    cli = argparse.ArgumentParser(description=__doc__)
    # Each entry: (flag, keyword arguments for add_argument).
    option_specs = [
        ("--model", {"required": True, "help": "Hugging Face model id"}),
        ("--input", {"required": True, "help": "Prompt file (.txt or .jsonl)"}),
        ("--output", {"required": True, "help": "Output safetensors path"}),
        ("--layer", {"type": int, "default": 0, "help": "Attention layer index"}),
        ("--head", {"type": int, "default": 0, "help": "Attention head index"}),
        ("--max-samples", {"type": int, "default": 8}),
        ("--max-length", {"type": int, "default": 4096}),
        ("--query-stride", {"type": int, "default": 64}),
        ("--device", {"default": "auto", "choices": ["auto", "cpu", "cuda", "mps"]}),
        ("--dtype", {"default": "float16", "choices": ["float16", "bfloat16", "float32"]}),
        ("--trust-remote-code", {"action": "store_true"}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def resolve_device(arg: str) -> torch.device:
    """Translate the --device CLI choice into a concrete torch.device.

    "auto" prefers CUDA, then MPS, then CPU. Explicitly requesting an
    accelerator that is unavailable now raises instead of silently falling
    back to CPU — the old fallthrough masked misconfigured environments.

    Raises:
        ValueError: if "cuda" or "mps" was requested but is unavailable.
    """
    if arg == "cuda":
        if not torch.cuda.is_available():
            raise ValueError("--device cuda requested but CUDA is not available")
        return torch.device("cuda")
    if arg == "mps":
        if not torch.backends.mps.is_available():
            raise ValueError("--device mps requested but MPS is not available")
        return torch.device("mps")
    if arg == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
    return torch.device("cpu")
def resolve_dtype(arg: str) -> torch.dtype:
    """Map a --dtype CLI string onto the matching torch dtype.

    Raises KeyError for any string outside the three supported choices,
    mirroring the original dict lookup.
    """
    if arg == "float16":
        return torch.float16
    if arg == "bfloat16":
        return torch.bfloat16
    if arg == "float32":
        return torch.float32
    raise KeyError(arg)
def load_prompts(path: Path, max_samples: int) -> List[str]:
    """Read up to `max_samples` non-empty prompts from a .txt or .jsonl file.

    JSONL records may store the prompt under "prompt", "text", "input", or
    "question" (checked in that order); any other suffix is treated as one
    prompt per line. Blank lines and promptless records are skipped.

    Raises:
        ValueError: if the file yields no prompts at all.
    """
    source = _jsonl_prompts(path) if path.suffix == ".jsonl" else _text_prompts(path)
    prompts: List[str] = []
    for prompt in source:
        prompts.append(prompt)
        if len(prompts) >= max_samples:
            break
    if not prompts:
        raise ValueError(f"no prompts found in {path}")
    return prompts


def _jsonl_prompts(path: Path) -> Iterable[str]:
    """Yield prompts from a JSONL file, skipping blank lines and records
    with no recognized prompt field."""
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            if not line.strip():
                continue
            record = json.loads(line)
            # NOTE(review): values are assumed to be strings; a numeric
            # "input" field would slip through untyped — confirm upstream.
            prompt = (
                record.get("prompt")
                or record.get("text")
                or record.get("input")
                or record.get("question")
            )
            if prompt:
                yield prompt


def _text_prompts(path: Path) -> Iterable[str]:
    """Yield stripped, non-empty lines from a plain-text prompt file."""
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            prompt = line.strip()
            if prompt:
                yield prompt
def resolve_layers(model):
    """Locate the list of transformer blocks across common model layouts.

    Checks, in order: `model.model.layers` (LLaMA-style wrappers),
    `model.layers`, and `model.transformer.h` (GPT-2-style).
    """
    inner = getattr(model, "model", None)
    transformer = getattr(model, "transformer", None)
    for layer_list in (
        getattr(inner, "layers", None),
        getattr(model, "layers", None),
        getattr(transformer, "h", None),
    ):
        if layer_list is not None:
            return layer_list
    raise ValueError("could not resolve transformer layers on this model")
def resolve_attention_module(layer):
    """Return the layer's attention submodule, trying common attribute names
    ("self_attn", "attention", "attn") in priority order."""
    found = next(
        (
            getattr(layer, name)
            for name in ("self_attn", "attention", "attn")
            if getattr(layer, name, None) is not None
        ),
        None,
    )
    if found is None:
        raise ValueError("could not resolve attention module on target layer")
    return found
def num_heads(module, config) -> int | None:
    """Best-effort lookup of the attention (query) head count.

    Checks the attention module itself first, then common config attribute
    names. Returns None when no candidate is present — the caller validates
    and raises; the old `-> int` annotation hid this None path.
    """
    for candidate in (
        getattr(module, "num_heads", None),
        getattr(config, "num_attention_heads", None),
        getattr(config, "n_head", None),
    ):
        if candidate is not None:
            return candidate
    return None
def num_kv_heads(module, config, fallback: int) -> int:
    """Resolve the key/value head count, defaulting to `fallback`.

    For plain multi-head attention the KV head count equals the query head
    count, which is what callers pass as `fallback`.
    """
    for candidate in (
        getattr(module, "num_key_value_heads", None),
        getattr(config, "num_key_value_heads", None),
        getattr(config, "n_kv_heads", None),
    ):
        # Falsy values (None or 0) are treated as absent, matching the
        # original `or`-chain semantics.
        if candidate:
            return candidate
    return fallback
def query_positions(seq_len: int, stride: int) -> np.ndarray:
    """Return int64 indices sampled every `stride` positions, always
    including the final position `seq_len - 1`.

    A stride below 1 is clamped to 1; an empty sequence yields an empty
    index array.
    """
    if not seq_len:
        return np.zeros((0,), dtype=np.int64)
    step = stride if stride > 0 else 1
    sampled = np.arange(0, seq_len, step, dtype=np.int64)
    last = seq_len - 1
    if sampled[-1] != last:
        sampled = np.append(sampled, last)
    return sampled
def to_cpu_float32(tensor: torch.Tensor) -> np.ndarray:
    """Detach `tensor` from the graph and materialize it as a float32
    NumPy array on the host."""
    return tensor.detach().cpu().to(torch.float32).numpy()
def main() -> None:
    """Capture one attention head's pre-RoPE Q/K/V projections and save them.

    For each prompt, a forward hook grabs the hidden states entering the
    chosen attention module; that module's own q/k/v projection layers are
    then re-applied to the captured states, so the saved tensors are the
    raw per-head projections (before rotary embeddings, per the metadata
    note below). Results plus the strided query positions are written to a
    safetensors file, and a JSON summary is printed to stdout.
    """
    args = parse_args()
    device = resolve_device(args.device)
    dtype = resolve_dtype(args.dtype)
    prompts = load_prompts(Path(args.input), args.max_samples)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        # Tokenizers that ship without a pad token reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    ).to(device)
    model.eval()
    layers = resolve_layers(model)
    if args.layer < 0 or args.layer >= len(layers):
        raise ValueError(f"layer {args.layer} is out of range for model with {len(layers)} layers")
    attention = resolve_attention_module(layers[args.layer])
    n_heads = num_heads(attention, model.config)
    n_kv_heads = num_kv_heads(attention, model.config, n_heads)
    if n_heads is None:
        raise ValueError("could not determine attention head count")
    if args.head < 0 or args.head >= n_heads:
        raise ValueError(f"head {args.head} is out of range for model with {n_heads} heads")
    # Grouped-query attention: several query heads share one KV head. Map the
    # chosen query head to its KV group; min() clamps in case of odd ratios.
    kv_group_size = max(1, n_heads // n_kv_heads)
    kv_head = min(args.head // kv_group_size, n_kv_heads - 1)
    # Single-slot mailbox written by the hook on each forward pass.
    captured_hidden_states = {}
    def hook(_module, hook_args, _hook_output):
        # The first positional argument to the attention module is its input
        # hidden states; stash a detached reference for re-projection below.
        hidden_states = hook_args[0]
        captured_hidden_states["value"] = hidden_states.detach()
    handle = attention.register_forward_hook(hook)
    key_rows: List[np.ndarray] = []
    value_rows: List[np.ndarray] = []
    query_rows: List[np.ndarray] = []
    query_position_rows: List[np.ndarray] = []
    with torch.inference_mode():
        for prompt in prompts:
            encoded = tokenizer(
                prompt,
                truncation=True,
                max_length=args.max_length,
                return_tensors="pt",
            )
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            # Clear the mailbox so a stale capture can't mask a hook failure.
            captured_hidden_states.clear()
            model(**encoded)
            hidden_states = captured_hidden_states.get("value")
            if hidden_states is None:
                raise RuntimeError("failed to capture attention input hidden states")
            # Re-apply the module's projections to the captured input.
            # NOTE(review): assumes the attention module exposes separate
            # q_proj/k_proj/v_proj (LLaMA-style); a fused GPT-2-style c_attn
            # would raise AttributeError here — confirm target models.
            query_proj = attention.q_proj(hidden_states)
            key_proj = attention.k_proj(hidden_states)
            value_proj = attention.v_proj(hidden_states)
            batch, seq_len, query_hidden = query_proj.shape
            _, _, key_hidden = key_proj.shape
            if batch != 1:
                raise RuntimeError("only batch size 1 traces are supported")
            head_dim = query_hidden // n_heads
            kv_head_dim = key_hidden // n_kv_heads
            if head_dim != kv_head_dim:
                raise RuntimeError(
                    f"query/key head dim mismatch ({head_dim} vs {kv_head_dim})"
                )
            # Split flat projections into heads and pick the target head:
            # (1, seq, H*D) -> (1, H, seq, D) -> (seq, D).
            queries = (
                query_proj.view(batch, seq_len, n_heads, head_dim)
                .permute(0, 2, 1, 3)[0, args.head]
            )
            keys = (
                key_proj.view(batch, seq_len, n_kv_heads, head_dim)
                .permute(0, 2, 1, 3)[0, kv_head]
            )
            values = (
                value_proj.view(batch, seq_len, n_kv_heads, head_dim)
                .permute(0, 2, 1, 3)[0, kv_head]
            )
            # Keys/values keep every position; queries are subsampled by
            # --query-stride (final position always included).
            positions = query_positions(seq_len, args.query_stride)
            key_rows.append(to_cpu_float32(keys))
            value_rows.append(to_cpu_float32(values))
            query_rows.append(to_cpu_float32(queries[positions]))
            query_position_rows.append(positions)
    handle.remove()
    # NOTE(review): np.stack below requires every prompt to produce the same
    # seq_len (and hence query count) — confirm prompts are fixed-length or
    # truncate identically, otherwise this raises on ragged inputs.
    seq_len = key_rows[0].shape[0]
    head_dim = key_rows[0].shape[1]
    query_count = query_rows[0].shape[0]
    keys = np.stack(key_rows, axis=0)
    values = np.stack(value_rows, axis=0)
    queries = np.stack(query_rows, axis=0)
    query_positions_tensor = np.stack(query_position_rows, axis=0)
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(
        {
            "keys": keys.astype(np.float32),
            "values": values.astype(np.float32),
            "queries": queries.astype(np.float32),
            "query_positions": query_positions_tensor.astype(np.int64),
        },
        str(output_path),
        # safetensors metadata values must be strings.
        metadata={
            "model": args.model,
            "benchmark": output_path.stem,
            "suite": Path(args.input).stem,
            "layer": str(args.layer),
            "head": str(args.head),
            "note": "pre-rope attention projections for TurboQuant benchmarking",
        },
    )
    # Machine-readable summary of what was traced, for downstream tooling.
    print(
        json.dumps(
            {
                "output": str(output_path),
                "samples": len(prompts),
                "seq_len": seq_len,
                "query_count": query_count,
                "head_dim": head_dim,
                "head": args.head,
                "kv_head": kv_head,
                "device": str(device),
            },
            indent=2,
        )
    )
# Script entry point: run the trace only when executed directly.
if __name__ == "__main__":
    main()