use crate::gigaam::{GigaAm, GigaAmConfig, Head, TransducerConfig};
use crate::state::{HasStateDict, StateDict};
use super::batch::test_config;
/// Returns the shared `test_config()` with a transducer section attached, so
/// that a model built from it carries the RNN-T head rather than the CTC one.
///
/// The vocabulary is eight synthetic tokens, `p0` through `p7`.
fn rnnt_test_config() -> GigaAmConfig {
    let mut config = test_config();

    let vocabulary = (0..8)
        .map(|idx| format!("p{idx}"))
        .collect::<Vec<String>>();
    // One class more than there are vocabulary entries — presumably the extra
    // slot is the blank symbol; confirm against `TransducerConfig`.
    let num_classes = vocabulary.len() + 1;

    let transducer = TransducerConfig {
        pred_hidden: 16,
        pred_rnn_layers: 1,
        joint_hidden: 16,
        num_classes,
        max_symbols_per_step: 10,
        vocabulary,
        sentencepiece: false,
    };
    config.transducer = Some(transducer);
    config
}
/// Gathers the state dicts of every model component into a single `StateDict`.
///
/// Key prefixes used:
/// - `subsampling` for the encoder's subsampling module,
/// - `layers.{i}` for the i-th encoder layer,
/// - `head` for whichever head variant (`Ctc` or `Rnnt`) the model carries.
fn compose_state_dict(model: &GigaAm) -> StateDict {
    let encoder = &model.encoder;
    let mut merged = encoder.subsampling.state_dict("subsampling");

    for (idx, block) in encoder.layers.iter().enumerate() {
        let prefix = format!("layers.{idx}");
        merged.extend(block.state_dict(&prefix));
    }

    // Both head variants are exported under the same `head` prefix; only the
    // entries below that prefix differ between them.
    let head_entries = match &model.head {
        Head::Ctc(ctc) => ctc.state_dict("head"),
        Head::Rnnt { head: rnnt, .. } => rnnt.state_dict("head"),
    };
    merged.extend(head_entries);

    merged
}
/// Loads `sd` into every component of `model`, using the same prefixes that
/// `compose_state_dict` writes (`subsampling`, `layers.{i}`, `head`).
///
/// # Panics
///
/// Panics (via `expect`) as soon as any component's `load_state_dict` returns
/// an error; the panic message names the component that failed.
fn reload(model: &mut GigaAm, sd: &StateDict) {
    let encoder = &mut model.encoder;
    encoder
        .subsampling
        .load_state_dict(sd, "subsampling")
        .expect("subsampling reload");

    encoder
        .layers
        .iter_mut()
        .enumerate()
        .for_each(|(idx, block)| {
            block
                .load_state_dict(sd, &format!("layers.{idx}"))
                .expect("conformer layer reload");
        });

    match &mut model.head {
        Head::Ctc(ctc) => {
            ctc.load_state_dict(sd, "head").expect("ctc head reload");
        }
        Head::Rnnt { head: rnnt, .. } => {
            rnnt.load_state_dict(sd, "head").expect("rnnt head reload");
        }
    }
}
/// Exports the state dict of a randomly initialised model built from
/// `test_config()`, spot-checks its key layout, and loads it into a second
/// model built from the same config.
///
/// NOTE(review): the reload step only proves that loading does not panic; the
/// reloaded values are never compared against the source model.
#[test]
fn gigaam_state_dict_round_trip_ctc() {
    let cfg = test_config();
    let source_model = GigaAm::with_random_weights(cfg.clone());
    let sd = compose_state_dict(&source_model);

    // A sample of keys from the subsampling module, the first two encoder
    // layers, and the flat CTC head.
    let required_keys = [
        "subsampling.conv1_weight",
        "layers.0.ffn1.norm.weight",
        "layers.0.ffn1.linear1.weight",
        "layers.0.ffn1.linear2.bias",
        "layers.0.mhsa.norm.weight",
        "layers.0.conv.norm.weight",
        "layers.0.ffn2.linear1.weight",
        "layers.0.final_norm.weight",
        "layers.1.ffn1.norm.weight",
        "head.weight",
        "head.bias",
    ];
    for key in required_keys {
        assert!(sd.contains_key(key), "missing key: {key}");
    }

    // Keys that the transducer test expects from an RNN-T head must be absent.
    let has_predictor_embed = sd.contains_key("head.predictor.embed");
    assert!(!has_predictor_embed, "CTC head must not emit predictor.embed");
    let has_joint_enc_w = sd.contains_key("head.joint.enc_w");
    assert!(!has_joint_enc_w, "CTC head must not emit joint.enc_w");

    let mut target_model = GigaAm::with_random_weights(cfg);
    reload(&mut target_model, &sd);
}
/// Exports the state dict of a randomly initialised model built from
/// `rnnt_test_config()`, spot-checks its key layout, and loads it into a
/// second model built from the same config.
///
/// NOTE(review): the reload step only proves that loading does not panic; the
/// reloaded values are never compared against the source model.
#[test]
fn gigaam_state_dict_round_trip_rnnt() {
    let cfg = rnnt_test_config();
    let model = GigaAm::with_random_weights(cfg.clone());
    let sd = compose_state_dict(&model);

    // A sample of encoder keys plus the predictor / joint entries that the
    // transducer head nests under `head.`.
    for key in [
        "subsampling.conv1_weight",
        "layers.0.ffn1.norm.weight",
        "layers.1.final_norm.weight",
        "head.predictor.embed",
        "head.predictor.lstm.0.w_ih",
        "head.predictor.lstm.0.b_hh",
        "head.joint.enc_w",
        "head.joint.enc_b",
        "head.joint.pred_w",
        "head.joint.out_w",
        "head.joint.out_b",
    ] {
        assert!(sd.contains_key(key), "missing key: {key}");
    }

    // The flat `head.weight` / `head.bias` pair is what the CTC test expects
    // from a CTC head; it must not appear here. The failure messages say
    // "RNN-T" so they match `Head::Rnnt` and the rest of this file.
    assert!(!sd.contains_key("head.weight"), "RNN-T head must not emit a bare `head.weight`");
    assert!(!sd.contains_key("head.bias"), "RNN-T head must not emit a bare `head.bias`");

    let mut empty = GigaAm::with_random_weights(cfg);
    reload(&mut empty, &sd);
}