use crate::error::{MIError, Result};
/// Hugging Face repository that every `transcoders:` entry is required to
/// point at; entries naming any other repo are rejected by the parser.
pub(crate) const GEMMASCOPE_WEIGHTS_REPO: &str = "google/gemma-scope-2b-pt-transcoders";

/// URL scheme each `transcoders:` entry must start with.
const HF_URL_PREFIX: &str = "hf://";

/// Largest `config.yaml`, in bytes, that the parser will look at (1 MiB).
const MAX_YAML_BYTES: usize = 1 << 20;

/// Maximum number of transcoder entries the parser will collect (1024).
const MAX_TRANSCODERS: usize = 1 << 10;
/// Extracts the repo-relative weight-file paths listed under the
/// `transcoders:` key of a GemmaScope `config.yaml`.
///
/// This is a small line-oriented scanner, not a general YAML parser. It
/// looks for a line consisting solely of `transcoders:` and then reads the
/// block sequence that follows, one entry per line:
///
/// ```text
/// transcoders:
///   - "hf://google/gemma-scope-2b-pt-transcoders/layer_0/.../params.npz"
/// ```
///
/// Entries may be double-quoted, single-quoted or bare. Both the indented
/// and the compact (dash at the key's own column) block-sequence layouts
/// are accepted. Blank lines and `#` comment lines are skipped everywhere.
/// The block ends at the first unindented line that is not a list item.
///
/// Each returned string is the entry with the `hf://` scheme and the
/// `GEMMASCOPE_WEIGHTS_REPO/` prefix removed, in file order.
///
/// # Errors
///
/// Returns [`MIError::Config`] when:
/// - `yaml` is longer than `MAX_YAML_BYTES`;
/// - no `transcoders:` key is present, or its list is empty;
/// - an indented line inside the block is not a `- ` list item;
/// - an entry lacks the `hf://` scheme or names a different repository;
/// - an entry's in-repo path is empty or contains an empty, `.` or `..`
///   segment (so it cannot escape the pinned repository or resolve to an
///   absolute path once joined onto a directory);
/// - more than `MAX_TRANSCODERS` entries are listed.
pub fn parse_gemmascope_config(yaml: &str) -> Result<Vec<String>> {
    // Bound the work done on input we do not control before touching it.
    if yaml.len() > MAX_YAML_BYTES {
        return Err(MIError::Config(format!(
            "gemmascope config.yaml is {} bytes; refusing to parse anything \
             larger than {MAX_YAML_BYTES} bytes",
            yaml.len()
        )));
    }

    let expected_repo_prefix = format!("{GEMMASCOPE_WEIGHTS_REPO}/");
    let mut npz_paths: Vec<String> = Vec::new();
    let mut in_transcoders_block = false;

    for raw_line in yaml.lines() {
        let line = raw_line.trim_end();
        let stripped = line.trim_start();
        if stripped.is_empty() || stripped.starts_with('#') {
            continue;
        }

        // Everything before the key is irrelevant to this scanner.
        if !in_transcoders_block {
            in_transcoders_block = stripped == "transcoders:";
            continue;
        }

        // YAML permits a block sequence to sit at the same column as its
        // parent key, so an unindented `- item` still belongs to the list.
        // Any other unindented line starts the next top-level key.
        let is_indented = line.starts_with(' ') || line.starts_with('\t');
        let is_list_item = stripped.starts_with("- ");
        if !is_indented && !is_list_item {
            break;
        }

        let after_dash = stripped.strip_prefix("- ").ok_or_else(|| {
            MIError::Config(format!(
                "unexpected non-list line inside transcoders block: {line}"
            ))
        })?;
        // More than one space may separate the dash from the scalar.
        let relpath = transcoder_relpath(after_dash.trim_start(), &expected_repo_prefix)?;

        if npz_paths.len() >= MAX_TRANSCODERS {
            return Err(MIError::Config(format!(
                "gemmascope config.yaml lists more than {MAX_TRANSCODERS} \
                 transcoders; refusing to continue parsing"
            )));
        }
        npz_paths.push(relpath.to_owned());
    }

    if !in_transcoders_block {
        return Err(MIError::Config(
            "no 'transcoders:' key found in gemmascope config.yaml".into(),
        ));
    }
    if npz_paths.is_empty() {
        return Err(MIError::Config(
            "'transcoders:' list is empty in gemmascope config.yaml".into(),
        ));
    }
    Ok(npz_paths)
}

/// Validates one `transcoders:` list entry and returns its path inside the
/// expected repository.
///
/// `entry` is the text after the list dash; `expected_repo_prefix` is the
/// repository name followed by `/`.
fn transcoder_relpath<'a>(entry: &'a str, expected_repo_prefix: &str) -> Result<&'a str> {
    // Remove one matching pair of surrounding quotes, if present.
    let unquoted = entry
        .strip_prefix('"')
        .and_then(|s| s.strip_suffix('"'))
        .or_else(|| {
            entry
                .strip_prefix('\'')
                .and_then(|s| s.strip_suffix('\''))
        })
        .unwrap_or(entry);

    let no_scheme = unquoted.strip_prefix(HF_URL_PREFIX).ok_or_else(|| {
        MIError::Config(format!(
            "transcoders entry missing '{HF_URL_PREFIX}' prefix: {unquoted}"
        ))
    })?;

    let relpath = no_scheme
        .strip_prefix(expected_repo_prefix)
        .ok_or_else(|| {
            MIError::Config(format!(
                "transcoders entry points at unexpected repo \
                 (expected {GEMMASCOPE_WEIGHTS_REPO}): {no_scheme}"
            ))
        })?;

    // The prefix check above is purely textual. Without this, an entry such
    // as `<repo>/../other/x` or `<repo>//abs/path` would pass the repo pin
    // yet name something outside it once a caller normalizes or joins it.
    let has_unsafe_segment = relpath
        .split(['/', '\\'])
        .any(|segment| matches!(segment, "" | "." | ".."));
    if has_unsafe_segment {
        return Err(MIError::Config(format!(
            "transcoders entry has an empty or non-normalized path inside the repo: {unquoted}"
        )));
    }

    Ok(relpath)
}
#[cfg(test)]
// Tests unwrap freely: a panic is how a failed expectation gets reported.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::{GEMMASCOPE_WEIGHTS_REPO, parse_gemmascope_config};
    use crate::error::MIError;

    /// Minimal well-formed config: a handful of scalar keys followed by a
    /// `transcoders:` list whose entries are indented beneath the key.
    const SAMPLE_YAML: &str = r#"# Transcoder Configuration Gemma Scope Transcoders (lowest L0)
model_name: "google/gemma-2-2b"
model_kind: "transcoder_set"
feature_input_hook: "ln2.hook_normalized"
feature_output_hook: 'hook_mlp_out'
transcoders:
  - "hf://google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz"
  - "hf://google/gemma-scope-2b-pt-transcoders/layer_1/width_16k/average_l0_65/params.npz"
  - "hf://google/gemma-scope-2b-pt-transcoders/layer_2/width_16k/average_l0_49/params.npz"
"#;

    /// Asserts that parsing `yaml` fails with an `MIError::Config` whose
    /// message contains `needle`.
    #[track_caller]
    fn assert_config_error(yaml: &str, needle: &str) {
        let err = parse_gemmascope_config(yaml).unwrap_err();
        assert!(
            matches!(&err, MIError::Config(msg) if msg.contains(needle)),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn parses_well_formed_yaml() {
        let paths = parse_gemmascope_config(SAMPLE_YAML).unwrap();
        assert_eq!(paths.len(), 3);
        assert_eq!(paths[0], "layer_0/width_16k/average_l0_76/params.npz");
        assert_eq!(paths[1], "layer_1/width_16k/average_l0_65/params.npz");
        assert_eq!(paths[2], "layer_2/width_16k/average_l0_49/params.npz");
    }

    /// Guards against an accidental edit of the pinned repository name.
    #[test]
    fn weights_repo_constant_is_stable() {
        assert_eq!(
            GEMMASCOPE_WEIGHTS_REPO,
            "google/gemma-scope-2b-pt-transcoders"
        );
    }

    #[test]
    fn parses_unquoted_entries() {
        let yaml = "transcoders:\n - hf://google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz\n";
        let paths = parse_gemmascope_config(yaml).unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "layer_0/width_16k/average_l0_76/params.npz");
    }

    #[test]
    fn parses_single_quoted_entries() {
        let yaml = "transcoders:\n - 'hf://google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz'\n";
        let paths = parse_gemmascope_config(yaml).unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "layer_0/width_16k/average_l0_76/params.npz");
    }

    #[test]
    fn rejects_missing_transcoders_key() {
        assert_config_error("model_name: foo\nmodel_kind: bar\n", "no 'transcoders:'");
    }

    /// The key is present but the very next line is another top-level key.
    #[test]
    fn rejects_empty_transcoders_list() {
        assert_config_error("transcoders:\nmodel_name: foo\n", "empty");
    }

    #[test]
    fn rejects_missing_hf_prefix() {
        let yaml = "transcoders:\n - \"google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz\"\n";
        assert_config_error(yaml, "hf://");
    }

    #[test]
    fn rejects_wrong_repo() {
        let yaml = "transcoders:\n - \"hf://google/some-other-repo/layer_0/width_16k/average_l0_76/params.npz\"\n";
        assert_config_error(yaml, "unexpected repo");
    }

    #[test]
    fn ignores_comments_and_blank_lines() {
        let yaml = "# A leading comment\n\ntranscoders:\n # comment inside the block\n - \"hf://google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz\"\n\n - \"hf://google/gemma-scope-2b-pt-transcoders/layer_1/width_16k/average_l0_65/params.npz\"\n";
        let paths = parse_gemmascope_config(yaml).unwrap();
        assert_eq!(paths.len(), 2);
    }

    /// Entries that appear after a following top-level key must be ignored.
    #[test]
    fn stops_at_next_top_level_key() {
        let yaml = "transcoders:\n - \"hf://google/gemma-scope-2b-pt-transcoders/layer_0/width_16k/average_l0_76/params.npz\"\nnext_key: foo\n - \"hf://google/gemma-scope-2b-pt-transcoders/layer_999/width_16k/average_l0_99/params.npz\"\n";
        let paths = parse_gemmascope_config(yaml).unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0], "layer_0/width_16k/average_l0_76/params.npz");
    }

    #[test]
    fn rejects_non_list_line_inside_block() {
        assert_config_error("transcoders:\n not_a_list_item: foo\n", "non-list line");
    }

    /// One byte over the size cap is enough to be refused.
    #[test]
    fn rejects_oversize_yaml() {
        let oversize = "x".repeat(super::MAX_YAML_BYTES + 1);
        assert_config_error(&oversize, "refusing to parse");
    }

    /// Builds a list with exactly one entry more than the cap allows.
    #[test]
    fn rejects_too_many_transcoders() {
        let mut yaml = String::from("transcoders:\n");
        for i in 0..=super::MAX_TRANSCODERS {
            yaml.push_str(&format!(
                " - \"hf://google/gemma-scope-2b-pt-transcoders/layer_{i}/width_16k/average_l0_1/params.npz\"\n"
            ));
        }
        assert_config_error(&yaml, "more than");
    }

    /// Parses the checked-in fixture file and spot-checks the first, a
    /// middle, and the last of its 26 entries.
    #[test]
    fn parses_real_curation_yaml_shape() {
        let yaml = include_str!("gemmascope_test_fixture.yaml");
        let paths = parse_gemmascope_config(yaml).unwrap();
        assert_eq!(paths.len(), 26);
        assert_eq!(paths[0], "layer_0/width_16k/average_l0_76/params.npz");
        assert_eq!(paths[11], "layer_11/width_16k/average_l0_5/params.npz");
        assert_eq!(paths[25], "layer_25/width_16k/average_l0_41/params.npz");
    }
}