# turboquant 0.1.1
#
# Implementation of Google's TurboQuant algorithm for vector quantization.
# Documentation (package metadata header; see the module docstring below).
#!/usr/bin/env python3
"""
Export real-model attention projections to a safetensors trace for the Rust benchmark CLI.

This script captures Q/K/V projections for one attention layer/head from a Hugging Face
causal LM and writes:

  keys            [samples, seq_len, head_dim]
  values          [samples, seq_len, head_dim]
  queries         [samples, query_count, head_dim]
  query_positions [samples, query_count]

The exported tensors are taken from the attention projections before rotary position
encoding so the Q and K vectors live in the same space without model-specific RoPE code.
That makes the trace portable across Gemma, Mistral, Llama-family models and the current
Rust benchmark harness.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Iterable, List

import numpy as np
import torch
from safetensors.numpy import save_file
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args() -> argparse.Namespace:
    """Build and parse the command-line options for the trace exporter."""
    cli = argparse.ArgumentParser(description=__doc__)
    # Required I/O arguments.
    cli.add_argument("--model", required=True, help="Hugging Face model id")
    cli.add_argument("--input", required=True, help="Prompt file (.txt or .jsonl)")
    cli.add_argument("--output", required=True, help="Output safetensors path")
    # Which attention projection to capture.
    cli.add_argument("--layer", type=int, default=0, help="Attention layer index")
    cli.add_argument("--head", type=int, default=0, help="Attention head index")
    # Sampling limits.
    cli.add_argument("--max-samples", type=int, default=8)
    cli.add_argument("--max-length", type=int, default=4096)
    cli.add_argument("--query-stride", type=int, default=64)
    # Runtime configuration.
    cli.add_argument("--device", default="auto", choices=["auto", "cpu", "cuda", "mps"])
    cli.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])
    cli.add_argument("--trust-remote-code", action="store_true")
    return cli.parse_args()


def resolve_device(arg: str) -> torch.device:
    """Translate the --device flag into a torch.device, preferring accelerators.

    "auto" picks CUDA, then MPS, then CPU. An explicitly requested but
    unavailable accelerator silently falls back to CPU.
    """
    want_cuda = arg in ("cuda", "auto")
    want_mps = arg in ("mps", "auto")
    if want_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    if want_mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def resolve_dtype(arg: str) -> torch.dtype:
    """Look up the torch dtype matching the --dtype flag value."""
    mapping = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    return mapping[arg]


# Keys probed, in priority order, to find the prompt text in a JSONL record.
_PROMPT_KEYS = ("prompt", "text", "input", "question")


def _prompt_from_json_line(line: str) -> str | None:
    """Extract the first truthy prompt field from one JSONL record, or None."""
    record = json.loads(line)
    for key in _PROMPT_KEYS:
        value = record.get(key)
        if value:
            return value
    return None


def load_prompts(path: Path, max_samples: int) -> List[str]:
    """Read up to ``max_samples`` prompts from a .txt or .jsonl file.

    ``.jsonl`` files yield one JSON object per line with the prompt under one
    of ``_PROMPT_KEYS``; any other suffix is treated as plain text, one prompt
    per line. The two formats previously duplicated the whole read loop; this
    version shares it.

    Raises:
        ValueError: if no prompts are found in the file.
    """
    is_jsonl = path.suffix == ".jsonl"
    prompts: List[str] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            if is_jsonl:
                if not line.strip():
                    # Blank JSONL lines are skipped entirely, bypassing the
                    # sample-count check (matches the original behavior).
                    continue
                prompt = _prompt_from_json_line(line)
            else:
                prompt = line.strip()
            if prompt:
                prompts.append(prompt)
            if len(prompts) >= max_samples:
                break
    if not prompts:
        raise ValueError(f"no prompts found in {path}")
    return prompts


def resolve_layers(model):
    """Find the stack of transformer blocks across common HF model layouts.

    Probes ``model.model.layers`` (Llama/Mistral style), ``model.layers``,
    and ``model.transformer.h`` (GPT-2 style), in that order.
    """
    probes = (
        (getattr(model, "model", None), "layers"),
        (model, "layers"),
        (getattr(model, "transformer", None), "h"),
    )
    for parent, attr in probes:
        found = getattr(parent, attr, None)
        if found is not None:
            return found
    raise ValueError("could not resolve transformer layers on this model")


def resolve_attention_module(layer):
    """Return the layer's attention submodule, trying common attribute names."""
    candidates = (getattr(layer, attr, None) for attr in ("self_attn", "attention", "attn"))
    found = next((module for module in candidates if module is not None), None)
    if found is None:
        raise ValueError("could not resolve attention module on target layer")
    return found


def num_heads(module, config) -> int | None:
    """Best-effort attention head count from the module or the model config.

    Probes ``module.num_heads``, then ``config.num_attention_heads``, then
    ``config.n_head``. Returns None when none of these attributes is present
    (the original annotation claimed ``int``, but the caller explicitly
    checks for None, so the honest return type is ``int | None``).
    """
    return (
        getattr(module, "num_heads", None)
        or getattr(config, "num_attention_heads", None)
        or getattr(config, "n_head", None)
    )


def num_kv_heads(module, config, fallback: int) -> int:
    """KV-head count for grouped-query models, else ``fallback`` (full MHA)."""
    probes = (
        (module, "num_key_value_heads"),
        (config, "num_key_value_heads"),
        (config, "n_kv_heads"),
    )
    for owner, attr in probes:
        value = getattr(owner, attr, None)
        # Truthiness (not an is-None check) mirrors the original `or` chain.
        if value:
            return value
    return fallback


def query_positions(seq_len: int, stride: int) -> np.ndarray:
    """Sampled query indices: every ``stride``-th position plus the final one.

    Returns an int64 array; empty for ``seq_len == 0``. Non-positive strides
    are clamped to 1.
    """
    if seq_len == 0:
        return np.zeros((0,), dtype=np.int64)
    step = max(stride, 1)
    sampled = np.arange(0, seq_len, step, dtype=np.int64)
    # Always include the last position so the final token is queried.
    if sampled[-1] != seq_len - 1:
        sampled = np.append(sampled, seq_len - 1)
    return sampled


def to_cpu_float32(tensor: torch.Tensor) -> np.ndarray:
    """Detach a tensor and materialize it as a float32 numpy array on the CPU."""
    detached = tensor.detach()
    return detached.to(dtype=torch.float32, device="cpu").numpy()


def main() -> None:
    """Export pre-RoPE Q/K/V projections for one layer/head to a safetensors trace.

    Pipeline: resolve device/dtype, read prompts, load tokenizer + model,
    locate the target attention module, capture its input hidden states via a
    forward hook, re-apply the layer's own q/k/v projections, then stack the
    per-prompt tensors and save them with descriptive metadata.
    """
    args = parse_args()
    device = resolve_device(args.device)
    dtype = resolve_dtype(args.dtype)
    prompts = load_prompts(Path(args.input), args.max_samples)

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        # Some checkpoints ship without a pad token; reuse EOS so tokenization works.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        trust_remote_code=args.trust_remote_code,
    ).to(device)
    model.eval()

    # Validate the requested layer/head against what the model actually has.
    layers = resolve_layers(model)
    if args.layer < 0 or args.layer >= len(layers):
        raise ValueError(f"layer {args.layer} is out of range for model with {len(layers)} layers")
    attention = resolve_attention_module(layers[args.layer])

    n_heads = num_heads(attention, model.config)
    n_kv_heads = num_kv_heads(attention, model.config, n_heads)
    if n_heads is None:
        raise ValueError("could not determine attention head count")
    if args.head < 0 or args.head >= n_heads:
        raise ValueError(f"head {args.head} is out of range for model with {n_heads} heads")

    # Grouped-query attention: several query heads share one KV head; map the
    # requested query head to its KV group (min() guards against overflow).
    kv_group_size = max(1, n_heads // n_kv_heads)
    kv_head = min(args.head // kv_group_size, n_kv_heads - 1)

    # Scratch slot the forward hook writes into on each model call.
    captured_hidden_states: dict[str, torch.Tensor] = {}

    def hook(_module, hook_args, _hook_output):
        # Forward hooks receive the module's positional inputs; the first one
        # is the hidden-states tensor entering the attention block.
        hidden_states = hook_args[0]
        captured_hidden_states["value"] = hidden_states.detach()

    handle = attention.register_forward_hook(hook)

    key_rows: List[np.ndarray] = []
    value_rows: List[np.ndarray] = []
    query_rows: List[np.ndarray] = []
    query_position_rows: List[np.ndarray] = []

    with torch.inference_mode():
        for prompt in prompts:
            encoded = tokenizer(
                prompt,
                truncation=True,
                max_length=args.max_length,
                return_tensors="pt",
            )
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
            # Clear before each forward pass so a stale capture can't leak through.
            captured_hidden_states.clear()
            model(**encoded)
            hidden_states = captured_hidden_states.get("value")
            if hidden_states is None:
                raise RuntimeError("failed to capture attention input hidden states")

            # Re-apply the layer's own projections to the captured input; this
            # yields Q/K/V before rotary position encoding is applied (assumes
            # the module exposes q_proj/k_proj/v_proj — Llama/Mistral style).
            query_proj = attention.q_proj(hidden_states)
            key_proj = attention.k_proj(hidden_states)
            value_proj = attention.v_proj(hidden_states)

            batch, seq_len, query_hidden = query_proj.shape
            _, _, key_hidden = key_proj.shape
            if batch != 1:
                raise RuntimeError("only batch size 1 traces are supported")

            head_dim = query_hidden // n_heads
            kv_head_dim = key_hidden // n_kv_heads
            if head_dim != kv_head_dim:
                raise RuntimeError(
                    f"query/key head dim mismatch ({head_dim} vs {kv_head_dim})"
                )

            # Reshape [1, seq, hidden] -> [heads, seq, head_dim] and slice out
            # the single head (KV head for keys/values) as [seq, head_dim].
            queries = (
                query_proj.view(batch, seq_len, n_heads, head_dim)
                .permute(0, 2, 1, 3)[0, args.head]
            )
            keys = (
                key_proj.view(batch, seq_len, n_kv_heads, head_dim)
                .permute(0, 2, 1, 3)[0, kv_head]
            )
            values = (
                value_proj.view(batch, seq_len, n_kv_heads, head_dim)
                .permute(0, 2, 1, 3)[0, kv_head]
            )

            positions = query_positions(seq_len, args.query_stride)
            key_rows.append(to_cpu_float32(keys))
            value_rows.append(to_cpu_float32(values))
            query_rows.append(to_cpu_float32(queries[positions]))
            query_position_rows.append(positions)

    handle.remove()

    seq_len = key_rows[0].shape[0]
    head_dim = key_rows[0].shape[1]
    query_count = query_rows[0].shape[0]

    # NOTE(review): np.stack requires every prompt to yield identical
    # seq_len/query_count; prompts that tokenize to different lengths will
    # raise here unless they all hit --max-length — confirm intended.
    keys = np.stack(key_rows, axis=0)
    values = np.stack(value_rows, axis=0)
    queries = np.stack(query_rows, axis=0)
    query_positions_tensor = np.stack(query_position_rows, axis=0)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_file(
        {
            "keys": keys.astype(np.float32),
            "values": values.astype(np.float32),
            "queries": queries.astype(np.float32),
            "query_positions": query_positions_tensor.astype(np.int64),
        },
        str(output_path),
        # safetensors metadata values must be strings.
        metadata={
            "model": args.model,
            "benchmark": output_path.stem,
            "suite": Path(args.input).stem,
            "layer": str(args.layer),
            "head": str(args.head),
            "note": "pre-rope attention projections for TurboQuant benchmarking",
        },
    )

    # Emit a machine-readable summary of what was exported.
    print(
        json.dumps(
            {
                "output": str(output_path),
                "samples": len(prompts),
                "seq_len": seq_len,
                "query_count": query_count,
                "head_dim": head_dim,
                "head": args.head,
                "kv_head": kv_head,
                "device": str(device),
            },
            indent=2,
        )
    )


# Script entry point: run the export only when invoked directly, not on import.
if __name__ == "__main__":
    main()