# turboquant 0.1.1
# Implementation of Google's TurboQuant algorithm for vector quantization
# Documentation
# NOTE(review): the three lines above are stray package-page text that was
# pasted before the shebang; left as comments so the module parses. The
# shebang below should be the first line of the file — confirm and clean up.
#!/usr/bin/env python3
"""
Export a lightweight decoder-only Hugging Face model to an ONNX bundle for TurboQuant's
ONNX Runtime-based real-model benchmark path.

The script keeps the export focused on CPU-friendly `text-generation-with-past` bundles:

- float32 export by default for the currently verified CPU path
- explicit support for past-key-values reuse
- tokenizer/config files copied into the output directory
- preset shortcuts for the currently documented lightweight models
"""

from __future__ import annotations

import argparse
import json
from dataclasses import dataclass
from pathlib import Path

# Guarded import of the heavyweight export stack: if either dependency is
# missing, exit with actionable install instructions instead of a raw
# ModuleNotFoundError traceback.
try:
    from optimum.exporters.onnx import main_export
    from transformers import AutoTokenizer
except ModuleNotFoundError as exc:
    # exc.name carries the missing module's name; fall back to a generic
    # label in the rare case it is unset.
    missing = exc.name or "export dependency"
    raise SystemExit(
        "Missing Python export dependency "
        f"{missing!r}. Install the real-model export stack first:\n"
        "  python3 -m venv .venv-real-model-export\n"
        "  . .venv-real-model-export/bin/activate\n"
        "  pip install -r scripts/requirements-real-model.txt"
    ) from exc


@dataclass(frozen=True)
class ModelPreset:
    """Immutable description of one named lightweight export preset.

    Fields:
        name: short identifier used on the CLI and written to the manifest
        model_id: Hugging Face hub id to download and export
        description: one-line human-readable summary
        opset: ONNX opset targeted when the user does not override it
        trust_remote_code: whether the hub repo requires custom modeling code
    """

    name: str
    model_id: str
    description: str
    opset: int = 18
    trust_remote_code: bool = False


# Registry of the documented lightweight presets, keyed by preset name.
PRESETS = {
    preset.name: preset
    for preset in (
        ModelPreset(
            name="distilgpt2",
            model_id="distilgpt2",
            description="Tiny GPT-2-family baseline that is practical on CPU",
        ),
        ModelPreset(
            name="smollm2-135m-instruct",
            model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
            description="Fastest documented CPU-friendly preset",
        ),
        ModelPreset(
            name="qwen2.5-0.5b-instruct",
            model_id="Qwen/Qwen2.5-0.5B-Instruct",
            description="Heavier but still practical developer-machine preset",
        ),
    )
}


def parse_args() -> argparse.Namespace:
    """Build the CLI for the export script and parse ``sys.argv``.

    Returns the parsed namespace; exits via argparse on invalid input.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # Local alias keeps each option definition on a few compact lines.
    add = parser.add_argument

    add(
        "--preset",
        choices=sorted(PRESETS),
        help="Named lightweight preset to export",
    )
    add(
        "--model",
        help="Explicit Hugging Face model id. Required when --preset is not used.",
    )
    add(
        "--output-dir",
        required=True,
        help="Directory to write the ONNX bundle into",
    )
    add(
        "--opset",
        type=int,
        help="ONNX opset. Defaults to the preset opset or 18.",
    )
    add(
        "--dtype",
        default="fp32",
        choices=["fp32", "fp16", "bf16"],
        help="Floating-point export dtype. fp32 is the verified CPU path today.",
    )
    add(
        "--task",
        default="text-generation-with-past",
        help="Optimum export task. Keep the '-with-past' suffix for KV-cache reuse.",
    )
    add(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use during export.",
    )
    add(
        "--attn-implementation",
        default="eager",
        choices=["eager", "sdpa"],
        help=(
            "Attention implementation to request during export. "
            "Use eager by default because it tends to produce simpler ONNX graphs."
        ),
    )
    add(
        "--trust-remote-code",
        action="store_true",
        help="Allow custom modeling code from the model repository.",
    )
    add(
        "--force",
        action="store_true",
        help="Allow exporting into a non-empty output directory.",
    )
    return parser.parse_args()


def resolve_model(args: argparse.Namespace) -> tuple[str, int, bool, str]:
    """Resolve the effective export parameters from the parsed CLI args.

    Returns a ``(model_id, opset, trust_remote_code, preset_name)`` tuple,
    where ``preset_name`` is ``"custom"`` for non-preset exports.

    Raises:
        SystemExit: when neither ``--preset`` nor ``--model`` was given.
    """
    if args.preset:
        preset = PRESETS[args.preset]
        # Use an explicit None check: `args.opset or ...` would silently
        # discard a falsy (e.g. 0) user-provided opset.
        opset = args.opset if args.opset is not None else preset.opset
        return (
            preset.model_id,
            opset,
            # A preset may require trusting remote code even if the flag is off.
            args.trust_remote_code or preset.trust_remote_code,
            preset.name,
        )
    if not args.model:
        raise SystemExit("pass either --preset or --model")
    return (
        args.model,
        args.opset if args.opset is not None else 18,
        args.trust_remote_code,
        "custom",
    )


def ensure_output_dir(path: Path, force: bool) -> None:
    """Create *path* (with parents) while guarding against clobbering data.

    Raises:
        SystemExit: when *path* exists as a non-directory, or is a non-empty
            directory and *force* was not requested.
    """
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)
        return
    if not path.is_dir():
        raise SystemExit(f"{path} exists and is not a directory")
    if not force and any(path.iterdir()):
        raise SystemExit(
            f"{path} is not empty; re-run with --force if you want to overwrite into it"
        )


def write_manifest(path: Path, payload: dict) -> None:
    """Write *payload* as pretty-printed JSON into the export directory."""
    manifest = path / "turboquant_real_model_export.json"
    document = json.dumps(payload, indent=2)
    # Trailing newline keeps the file friendly to text tooling.
    manifest.write_text(document + "\n", encoding="utf-8")


def main() -> None:
    """CLI entry point: resolve the model, export it, and record a manifest."""
    args = parse_args()
    bundle_dir = Path(args.output_dir)
    ensure_output_dir(bundle_dir, args.force)

    model_id, opset, trust_remote_code, preset_name = resolve_model(args)

    # Save the tokenizer next to the ONNX graph so the bundle is self-contained.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=trust_remote_code,
    )
    tokenizer.save_pretrained(bundle_dir)

    main_export(
        model_name_or_path=model_id,
        output=bundle_dir,
        task=args.task,
        opset=opset,
        dtype=args.dtype,
        device=args.device,
        trust_remote_code=trust_remote_code,
        model_kwargs={"attn_implementation": args.attn_implementation},
    )

    # Record exactly what was exported so downstream benchmarks can audit it.
    manifest = {
        "preset": preset_name,
        "model_id": model_id,
        "task": args.task,
        "dtype": args.dtype,
        "device": args.device,
        "opset": opset,
        "trust_remote_code": trust_remote_code,
        "attn_implementation": args.attn_implementation,
        "notes": [
            "TurboQuant real-model evaluation currently targets ONNX Runtime on CPU.",
            "Use fp32 export unless you have validated another dtype on your target runtime.",
        ],
    }
    write_manifest(bundle_dir, manifest)

    print(f"Exported {model_id} to {bundle_dir}")


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()