use crate::state::HasStateDict;
use crate::wespeaker::{WeSpeakerConfig, WeSpeakerResNet34};
/// Checks that the exported state dict contains representative keys from every
/// part of the network, and that it can be loaded into a second model which then
/// exports the same set of keys.
#[test]
fn state_dict_round_trip() {
    let cfg = WeSpeakerConfig::new();
    let model = WeSpeakerResNet34::with_zero_weights(cfg.clone());
    let sd = model.state_dict("");

    // Spot-check keys from the stem, residual blocks, downsample branches and
    // the `seg_1` layer.
    for key in [
        "conv1.weight",
        "bn1.weight",
        "bn1.bias",
        "bn1.running_mean",
        "bn1.running_var",
        "layer1.0.conv1.weight",
        "layer1.0.bn1.weight",
        "layer1.2.conv2.weight",
        "layer2.0.downsample.0.weight",
        "layer2.0.downsample.1.running_var",
        "layer3.0.downsample.0.weight",
        "layer4.0.downsample.0.weight",
        "layer4.2.bn2.weight",
        "seg_1.weight",
        "seg_1.bias",
    ] {
        assert!(sd.contains_key(key), "missing key: {key}");
    }

    // Load into an independently constructed model. It is zero-weighted too,
    // not empty, so it is named for its role rather than its contents.
    let mut reloaded = WeSpeakerResNet34::with_zero_weights(cfg);
    reloaded.load_state_dict(&sd, "").expect("load round-trip");

    // Succeeding at load alone does not prove a round trip. The reloaded model
    // must also export exactly the entries it was given.
    let reloaded_sd = reloaded.state_dict("");
    assert_eq!(
        sd.len(),
        reloaded_sd.len(),
        "round-trip changed the number of state-dict entries"
    );
    for key in sd.keys() {
        assert!(reloaded_sd.contains_key(key), "key lost in round-trip: {key}");
    }
}
/// The `with_max_batch_size` builder stores the requested value in the
/// `max_batch_size` field of the config it returns.
#[test]
fn config_max_batch_size_with() {
    let base = WeSpeakerConfig::new();
    let configured = base.with_max_batch_size(16);
    assert_eq!(configured.max_batch_size, 16);
}
/// Checks that the `seg_1` parameters have the shapes expected by this test's
/// pyannote reference: weight `(256, 5120)` and bias `(256,)`.
#[test]
fn shapes_match_pyannote_reference() {
    let model = WeSpeakerResNet34::with_zero_weights(WeSpeakerConfig::new());
    let sd = model.state_dict("");

    // Look up a state-dict entry and collect its dimensions as `usize`. A missing
    // key panics with the key name; the `shape()` / `as_const()` results are
    // unwrapped exactly as before.
    let dims = |key: &str| -> Vec<usize> {
        sd.get(key)
            .unwrap_or_else(|| panic!("missing key: {key}"))
            .shape()
            .unwrap()
            .iter()
            .map(|d| d.as_const().unwrap())
            .collect()
    };

    assert_eq!(
        dims("seg_1.weight"),
        vec![256, 5120],
        "seg_1 weight must be (embed_dim=256, stats_dim*2=5120)"
    );
    assert_eq!(dims("seg_1.bias"), vec![256], "seg_1 bias must be (embed_dim=256,)");
}