from __future__ import annotations
import argparse
import json
from dataclasses import dataclass
from pathlib import Path
# Import the heavyweight export dependencies up front so a missing install
# fails fast with actionable setup instructions instead of a bare traceback.
try:
    from optimum.exporters.onnx import main_export
    from transformers import AutoTokenizer
except ModuleNotFoundError as exc:
    # exc.name is the missing module's name; fall back to a generic label
    # in case it is unset.
    missing = exc.name or "export dependency"
    raise SystemExit(
        "Missing Python export dependency "
        f"{missing!r}. Install the real-model export stack first:\n"
        " python3 -m venv .venv-real-model-export\n"
        " . .venv-real-model-export/bin/activate\n"
        " pip install -r scripts/requirements-real-model.txt"
    ) from exc
@dataclass(frozen=True)
class ModelPreset:
    """Immutable, curated export configuration for one Hugging Face model."""

    # CLI-facing preset key; also recorded as "preset" in the export manifest.
    name: str
    # Hugging Face Hub model id passed to the tokenizer loader and exporter.
    model_id: str
    # Human-readable summary of the preset (not referenced elsewhere in this file).
    description: str
    # ONNX opset used unless overridden by --opset on the command line.
    opset: int = 18
    # Whether the preset requires custom modeling code from the model repo.
    trust_remote_code: bool = False
# Registry of curated presets, keyed by the value accepted by --preset.
# All entries are small models intended to be practical on a developer CPU.
PRESETS = {
    "distilgpt2": ModelPreset(
        name="distilgpt2",
        model_id="distilgpt2",
        description="Tiny GPT-2-family baseline that is practical on CPU",
    ),
    "smollm2-135m-instruct": ModelPreset(
        name="smollm2-135m-instruct",
        model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
        description="Fastest documented CPU-friendly preset",
    ),
    "qwen2.5-0.5b-instruct": ModelPreset(
        name="qwen2.5-0.5b-instruct",
        model_id="Qwen/Qwen2.5-0.5B-Instruct",
        description="Heavier but still practical developer-machine preset",
    ),
}
def parse_args() -> argparse.Namespace:
    """Build the exporter's command-line interface and parse sys.argv."""
    p = argparse.ArgumentParser(description=__doc__)
    add = p.add_argument

    # Model selection: one of --preset / --model supplies the model id.
    add(
        "--preset",
        choices=sorted(PRESETS),
        help="Named lightweight preset to export",
    )
    add(
        "--model",
        help="Explicit Hugging Face model id. Required when --preset is not used.",
    )

    # Output destination.
    add(
        "--output-dir",
        required=True,
        help="Directory to write the ONNX bundle into",
    )

    # Export knobs.
    add(
        "--opset",
        type=int,
        help="ONNX opset. Defaults to the preset opset or 18.",
    )
    add(
        "--dtype",
        default="fp32",
        choices=["fp32", "fp16", "bf16"],
        help="Floating-point export dtype. fp32 is the verified CPU path today.",
    )
    add(
        "--task",
        default="text-generation-with-past",
        help="Optimum export task. Keep the '-with-past' suffix for KV-cache reuse.",
    )
    add(
        "--device",
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use during export.",
    )
    add(
        "--attn-implementation",
        default="eager",
        choices=["eager", "sdpa"],
        help=(
            "Attention implementation to request during export. "
            "Use eager by default because it tends to produce simpler ONNX graphs."
        ),
    )

    # Safety toggles.
    add(
        "--trust-remote-code",
        action="store_true",
        help="Allow custom modeling code from the model repository.",
    )
    add(
        "--force",
        action="store_true",
        help="Allow exporting into a non-empty output directory.",
    )

    return p.parse_args()
def resolve_model(args: argparse.Namespace) -> tuple[str, int, bool, str]:
    """Resolve (model_id, opset, trust_remote_code, preset_name) from CLI args.

    Preset settings act as defaults that explicit flags override. Raises
    SystemExit when neither --preset nor --model was supplied.
    """
    if args.preset:
        preset = PRESETS[args.preset]
        # `is not None` rather than `or`: a falsy explicit value (e.g.
        # `--opset 0`) must not silently fall back to the preset default.
        opset = args.opset if args.opset is not None else preset.opset
        return (
            preset.model_id,
            opset,
            # --trust-remote-code can only widen trust, never revoke a
            # preset's requirement for remote code.
            args.trust_remote_code or preset.trust_remote_code,
            preset.name,
        )
    if not args.model:
        raise SystemExit("pass either --preset or --model")
    opset = args.opset if args.opset is not None else 18
    return args.model, opset, args.trust_remote_code, "custom"
def ensure_output_dir(path: Path, force: bool) -> None:
    """Create *path* (including parents) if needed, refusing unsafe targets.

    Exits with an error when *path* is an existing non-directory, or an
    existing non-empty directory and --force was not given.
    """
    if path.exists() and not path.is_dir():
        raise SystemExit(f"{path} exists and is not a directory")
    if path.is_dir() and not force and any(path.iterdir()):
        raise SystemExit(
            f"{path} is not empty; re-run with --force if you want to overwrite into it"
        )
    # exist_ok makes this a no-op for an already-validated directory.
    path.mkdir(parents=True, exist_ok=True)
def write_manifest(path: Path, payload: dict) -> None:
    """Serialize *payload* as pretty-printed JSON inside the export directory."""
    manifest = path / "turboquant_real_model_export.json"
    body = json.dumps(payload, indent=2)
    # Trailing newline keeps the file friendly to POSIX tools and diffs.
    manifest.write_text(body + "\n", encoding="utf-8")
def main() -> None:
    """CLI entry point: export tokenizer + ONNX model and record a manifest.

    Pure orchestration — argument parsing, output validation, tokenizer
    persistence, the optimum export itself, and manifest bookkeeping.
    """
    args = parse_args()
    out_dir = Path(args.output_dir)
    ensure_output_dir(out_dir, args.force)

    model_id, opset, trust_remote_code, preset_name = resolve_model(args)

    # Persist the tokenizer alongside the ONNX graph so the bundle is
    # self-contained for downstream runtimes.
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    tok.save_pretrained(out_dir)

    main_export(
        model_name_or_path=model_id,
        output=out_dir,
        task=args.task,
        opset=opset,
        dtype=args.dtype,
        device=args.device,
        trust_remote_code=trust_remote_code,
        model_kwargs={"attn_implementation": args.attn_implementation},
    )

    # Record exactly what was exported so results are reproducible later.
    manifest = {
        "preset": preset_name,
        "model_id": model_id,
        "task": args.task,
        "dtype": args.dtype,
        "device": args.device,
        "opset": opset,
        "trust_remote_code": trust_remote_code,
        "attn_implementation": args.attn_implementation,
        "notes": [
            "TurboQuant real-model evaluation currently targets ONNX Runtime on CPU.",
            "Use fp32 export unless you have validated another dtype on your target runtime.",
        ],
    }
    write_manifest(out_dir, manifest)

    print(f"Exported {model_id} to {out_dir}")
# Run the export only when invoked as a script, not on import.
if __name__ == "__main__":
    main()