import json
import os
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration (all from environment variables) -------------------------
# Model checkpoint to load; overridable via RLX_GEMMA4_REFERENCE_MODEL.
MODEL_ID = os.environ.get("RLX_GEMMA4_REFERENCE_MODEL", "google/gemma-4-12B")
# Prompt whose final-position logits are dumped.
PROMPT = os.environ.get(
    "RLX_GEMMA4_REFERENCE_PROMPT",
    "The quick brown fox jumps over the lazy dog.",
)
# Output directory: FIXTURE_DIR takes precedence; an unset or empty value falls
# through to RLX_GEMMA4_FIXTURE.
FIXTURE_DIR = os.environ.get("FIXTURE_DIR") or os.environ.get("RLX_GEMMA4_FIXTURE")

if FIXTURE_DIR:
    # Create the directory (and parents) if needed; existing dirs are fine.
    os.makedirs(FIXTURE_DIR, exist_ok=True)
else:
    # No output location configured: report and exit with status 2.
    print("error: FIXTURE_DIR (or RLX_GEMMA4_FIXTURE) must be set", file=sys.stderr)
    sys.exit(2)
# --- Model loading -----------------------------------------------------------
print(f"[dump] loading {MODEL_ID}…", file=sys.stderr)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Load weights as float32 on CPU. Presumably full precision is wanted because
# this output serves as a numerical reference; confirm against consumers.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
)
# Put the module in inference mode (e.g. disables dropout). eval() returns the
# same module, so calling it as a statement is equivalent to chaining it.
model.eval()
# --- Forward pass ------------------------------------------------------------
print(f"[dump] tokenizing prompt: {PROMPT!r}", file=sys.stderr)
encoded = tokenizer(PROMPT, return_tensors="pt")
ids = encoded.input_ids  # batch of one sequence; indexed as ids[0] below
token_count = ids.shape[1]
print(f"[dump] {token_count} tokens", file=sys.stderr)

# Gradients are never needed for a reference dump.
with torch.no_grad():
    out = model(ids)

# Logits at the final position of the single batch row, as a plain Python list.
last_logits = out.logits[0, -1].tolist()
vocab_len = len(last_logits)
print(f"[dump] last-logits len = {vocab_len} (vocab)", file=sys.stderr)
# --- Fixture output ----------------------------------------------------------
# Writes {"prompt", "tokens", "logits"} to <FIXTURE_DIR>/reference.json.
path = os.path.join(FIXTURE_DIR, "reference.json")
payload = {
    "prompt": PROMPT,
    "tokens": ids[0].tolist(),
    "logits": last_logits,
}

# Write to a sibling temp file, then atomically swap it into place. An
# interrupted run (OOM, Ctrl-C) then never leaves a truncated reference.json
# behind for consumers to mis-parse.
tmp_path = path + ".tmp"
try:
    with open(tmp_path, "w", encoding="utf-8") as f:
        # allow_nan=False: fail loudly on non-finite logits instead of emitting
        # NaN/Infinity tokens, which are not valid JSON and break strict parsers.
        json.dump(payload, f, allow_nan=False)
    os.replace(tmp_path, path)
except BaseException:
    # Clean up the partial temp file, then propagate the original error.
    if os.path.exists(tmp_path):
        os.remove(tmp_path)
    raise
print(f"[dump] wrote {path}", file=sys.stderr)