import json
import os
import platform
from pathlib import Path
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
# Hugging Face repo whose config.yaml lists one hf:// transcoder URL per layer.
CURATION_REPO: str = "mntss/gemma-scope-transcoders"
# Hugging Face repo that hosts the per-layer transcoder .npz weight files.
WEIGHTS_REPO: str = "google/gemma-scope-2b-pt-transcoders"
# Layers for which reference test cases are generated.
TEST_LAYERS: list[int] = [0, 12, 25]
# Number of seeded random residual vectors encoded per test layer.
N_SEEDS_PER_LAYER: int = 3
# Maximum number of top activations recorded per test case.
TOP_K: int = 10
def _require(condition: bool, message: str) -> None:
    """Raise ``ValueError(message)`` when ``condition`` is false.

    Used instead of ``assert`` so that validation of downloaded data still
    runs when Python is started with ``-O``.
    """
    if not condition:
        raise ValueError(message)


def _load_layer_relpaths() -> dict[int, str]:
    """Return a map from layer index to the transcoder file path in WEIGHTS_REPO.

    Downloads ``config.yaml`` from CURATION_REPO and reads its ``transcoders``
    list. Every entry must be an ``hf://<WEIGHTS_REPO>/...`` URL whose first
    path component names the layer (``layer_<i>``).

    Raises:
        ValueError: if the list length is wrong, a URL points outside
            WEIGHTS_REPO, a layer appears twice, or the entries do not cover
            exactly layers 0-25.
    """
    n_layers = 26  # one transcoder per Gemma 2 2B layer
    yaml_path = hf_hub_download(CURATION_REPO, "config.yaml")
    with open(yaml_path, encoding="utf-8") as f:
        curation = yaml.safe_load(f)
    transcoders_urls: list[str] = curation["transcoders"]
    _require(
        len(transcoders_urls) == n_layers,
        "expected 26 entries (one per Gemma 2 2B layer), "
        f"got {len(transcoders_urls)}",
    )

    prefix = f"hf://{WEIGHTS_REPO}/"
    layer_to_relpath: dict[int, str] = {}
    for url in transcoders_urls:
        _require(url.startswith(prefix), f"unexpected URL prefix: {url}")
        relpath = url.removeprefix(prefix)
        layer_id = int(relpath.split("/")[0].removeprefix("layer_"))
        # Without this check a duplicate would silently overwrite an entry.
        _require(
            layer_id not in layer_to_relpath,
            f"duplicate entry for layer {layer_id}: {url}",
        )
        layer_to_relpath[layer_id] = relpath
    _require(
        sorted(layer_to_relpath) == list(range(n_layers)),
        f"expected layers 0..{n_layers - 1}, got {sorted(layer_to_relpath)}",
    )
    return layer_to_relpath


def _load_encoder(
    relpath: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Download one layer's ``.npz`` and return its JumpReLU encoder parameters.

    Args:
        relpath: Path of the ``.npz`` file inside WEIGHTS_REPO.

    Returns:
        ``(w_enc, b_enc, threshold)`` as float32 tensors. ``w_enc`` is
        transposed from the on-disk ``[d_model, n_features]`` layout to
        ``[n_features, d_model]``, so ``w_enc @ residual`` yields one
        pre-activation per feature.

    Raises:
        ValueError: if any array has an unexpected shape or ``W_skip`` is
            present.
    """
    npz_path = hf_hub_download(WEIGHTS_REPO, relpath)
    # np.load returns an NpzFile that keeps the archive open. The context
    # manager closes it once the arrays below have been read into memory.
    with np.load(npz_path) as params:
        w_enc_disk = torch.from_numpy(params["W_enc"]).float()
        b_enc = torch.from_numpy(params["b_enc"]).float()
        threshold = torch.from_numpy(params["threshold"]).float()
        # The decoder is only shape-checked, so keep its shapes, not its data.
        w_dec_shape = params["W_dec"].shape
        b_dec_shape = params["b_dec"].shape
        has_skip = "W_skip" in params.files

    d_model, n_features = w_enc_disk.shape
    _require(b_enc.shape == (n_features,), f"b_enc shape {tuple(b_enc.shape)}")
    _require(
        threshold.shape == (n_features,),
        f"threshold shape {tuple(threshold.shape)}",
    )
    _require(
        w_dec_shape == (n_features, d_model), f"W_dec shape {w_dec_shape}"
    )
    _require(b_dec_shape == (d_model,), f"b_dec shape {b_dec_shape}")
    _require(
        not has_skip,
        "GemmaScope is a pure JumpReLU transcoder; W_skip should not be present",
    )
    return w_enc_disk.T.contiguous(), b_enc, threshold


def _make_test_case(
    layer: int,
    seed: int,
    w_enc: torch.Tensor,
    b_enc: torch.Tensor,
    threshold: torch.Tensor,
) -> dict:
    """Encode one seeded random residual vector and record its top activations.

    Computes ``pre = w_enc @ residual + b_enc`` (``w_enc`` is already
    ``[n_features, d_model]``), then the JumpReLU
    ``acts = pre * (pre > threshold)``.

    Returns:
        A JSON-serialisable dict containing the residual input, the count of
        strictly positive activations, and up to TOP_K ``(index, activation)``
        pairs in descending order.
    """
    torch.manual_seed(seed)
    residual = torch.randn(w_enc.shape[1])
    pre_acts = w_enc @ residual + b_enc
    acts = pre_acts * (pre_acts > threshold).float()
    n_active = int((acts > 0).sum())
    # topk is capped at n_active, so only strictly positive activations are
    # recorded.
    top_vals, top_idx = acts.topk(min(TOP_K, n_active))
    return {
        "layer": layer,
        "seed": seed,
        "residual": residual.tolist(),
        "n_active": n_active,
        # The key name is part of the output schema; it holds up to TOP_K entries.
        "top_10": [
            {"index": int(idx), "activation": float(val)}
            for idx, val in zip(top_idx, top_vals, strict=True)
        ],
    }


def main() -> None:
    """Generate ``plt_gemma_reference.json`` next to this script.

    For each layer in TEST_LAYERS, loads the GemmaScope transcoder encoder
    and encodes N_SEEDS_PER_LAYER seeded random residual vectors. It writes
    the inputs, active-feature counts and top-TOP_K activations as reference
    data.

    Raises:
        ValueError: if the curated config or any downloaded weight file fails
            validation.
    """
    # Request deterministic kernels so the reference values are reproducible.
    # CUBLAS_WORKSPACE_CONFIG only affects CUDA matmuls and is set
    # defensively before determinism is switched on.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")
    torch.use_deterministic_algorithms(True)

    print(f"GemmaScope PLT reference generation for {WEIGHTS_REPO}")
    print(f"Curation: {CURATION_REPO}/config.yaml")
    print(f"Test layers: {TEST_LAYERS}, seeds per layer: {N_SEEDS_PER_LAYER}")
    print(f"torch {torch.__version__} on {platform.platform()}")
    print()

    layer_to_relpath = _load_layer_relpaths()

    results: dict = {
        "weights_repo": WEIGHTS_REPO,
        "curation_repo": CURATION_REPO,
        "methodology": "from-first-principles encoder oracle (no circuit-tracer)",
        "schema": "GemmaScopeNpz",
        "encoder_formula": "pre = W_enc.T @ residual + b_enc; acts = pre * (pre > threshold)",
        "torch_version": torch.__version__,
        "platform": platform.platform(),
        "d_model": None,
        "n_features_per_layer": None,
        "test_cases": [],
    }

    for layer in TEST_LAYERS:
        relpath = layer_to_relpath[layer]
        w_enc, b_enc, threshold = _load_encoder(relpath)
        n_features, d_model = w_enc.shape
        print(
            f"Layer {layer} ({relpath}): W_enc [{d_model}, {n_features}] -> "
            f"transposed to [{n_features}, {d_model}], "
            f"threshold [{threshold.shape[0]}], W_skip absent"
        )

        # All layers must share the same dimensions; record them once.
        if results["d_model"] is None:
            results["d_model"] = d_model
            results["n_features_per_layer"] = n_features
        else:
            _require(results["d_model"] == d_model, "d_model drifted across layers")
            _require(
                results["n_features_per_layer"] == n_features,
                "n_features drifted across layers",
            )

        for seed_idx in range(N_SEEDS_PER_LAYER):
            # Unique per (layer, seed_idx) while layer < 100.
            seed = seed_idx * 100 + layer
            test_case = _make_test_case(layer, seed, w_enc, b_enc, threshold)
            results["test_cases"].append(test_case)

            top = test_case["top_10"]
            top_feat = f"L{layer}:{top[0]['index']}" if top else "none"
            top_act = f"{top[0]['activation']:.4f}" if top else "N/A"
            print(
                f" seed={seed:4d}: {test_case['n_active']:6d} active / {n_features} features, "
                f"top={top_feat} ({top_act})"
            )

    out_path = Path(__file__).parent / "plt_gemma_reference.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    n_cases = len(results["test_cases"])
    file_size = out_path.stat().st_size
    print(
        f"\nSaved {n_cases} test cases to {out_path} "
        f"({file_size / 1024:.1f} KB)"
    )
# Generate the reference file only when run as a script, not when imported.
if __name__ == "__main__":
    main()