use std::path::Path;
use snafu::ResultExt;
use svod_dtype::DType;
use svod_ir::SInt;
use svod_tensor::{BoundVariable, Tensor};
use crate::blocks::{BatchNormWeights, BlockKind, Conv2dWeights, ResidualStage, remap};
use crate::state::{self, HasStateDict, StateDict, get_tensor};
use super::error::{HubSnafu, PickleSnafu, Result, StateSnafu, TensorSnafu};
use super::pickle;
use super::tstp::tstp_forward;
/// Base channel width; the four residual stages output 1x, 2x, 4x and 8x this value.
pub const M_CHANNELS: usize = 32;
/// Number of mel bins the model is sized for; together with the widest stage it
/// determines the input width of the final `seg_1` projection.
pub const NUM_MEL_BINS: usize = 80;
/// Dimensionality of the embedding produced by the final `seg_1` projection.
pub const EMBED_DIM: usize = 256;
/// Block counts handed to `ResidualStage::empty` for the four stages, in order.
pub const NUM_BLOCKS: [usize; 4] = [3, 4, 6, 3];
/// Construction-time options for the WeSpeaker ResNet34 model.
///
/// Build one with [`WeSpeakerConfig::new`] (or `Default`) and adjust it with the
/// `with_*` builder methods.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WeSpeakerConfig {
    /// Maximum batch size setting; defaults to 1.
    ///
    /// NOTE(review): not read by the model code in this module; presumably used by
    /// callers to size input buffers — confirm.
    pub max_batch_size: usize,
}

impl WeSpeakerConfig {
    /// Returns the default configuration (`max_batch_size == 1`).
    #[must_use]
    pub fn new() -> Self {
        Self { max_batch_size: 1 }
    }

    /// Returns this configuration with `max_batch_size` replaced.
    ///
    /// Consumes `self`, so the returned value must be kept. Ignoring it silently
    /// discards the setting, hence `#[must_use]`.
    #[must_use]
    pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self {
        self.max_batch_size = max_batch_size;
        self
    }
}

impl Default for WeSpeakerConfig {
    /// Same as [`WeSpeakerConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Speaker-embedding network made of a convolutional stem, four residual
/// stages and a final linear projection (`seg_1`).
///
/// Construct it with `with_zero_weights` or one of the `from_*` loaders.
#[derive(Clone)]
pub struct WeSpeakerResNet34 {
    /// Options supplied at construction time.
    pub config: WeSpeakerConfig,
    /// Stem convolution; its output feeds `stem_bn`.
    stem_conv: Conv2dWeights,
    /// Batch-norm parameters applied to the stem convolution output before ReLU.
    stem_bn: BatchNormWeights,
    /// First residual stage (applied right after the stem).
    stage1: ResidualStage,
    /// Second residual stage.
    stage2: ResidualStage,
    /// Third residual stage.
    stage3: ResidualStage,
    /// Fourth and widest residual stage; its output is pooled.
    stage4: ResidualStage,
    /// Weight of the final projection applied to the pooled statistics.
    seg_1_weight: Tensor,
    /// Bias of the final projection.
    seg_1_bias: Tensor,
}
/// Creates a `Float32` placeholder tensor of `shape` via `Tensor::zeros`.
///
/// # Panics
///
/// Panics if `Tensor::zeros` returns an error; placeholder allocation is
/// treated as an invariant rather than a recoverable failure.
fn zeros(shape: &[usize]) -> Tensor {
    let allocated = Tensor::zeros(shape, DType::Float32);
    allocated.expect("zeros for wespeaker placeholder must succeed")
}
impl WeSpeakerResNet34 {
pub fn with_zero_weights(config: WeSpeakerConfig) -> Self {
let stage1 = ResidualStage::empty(BlockKind::Basic, M_CHANNELS, M_CHANNELS, NUM_BLOCKS[0], 1);
let stage2 = ResidualStage::empty(BlockKind::Basic, M_CHANNELS, M_CHANNELS * 2, NUM_BLOCKS[1], 2);
let stage3 = ResidualStage::empty(BlockKind::Basic, M_CHANNELS * 2, M_CHANNELS * 4, NUM_BLOCKS[2], 2);
let stage4 = ResidualStage::empty(BlockKind::Basic, M_CHANNELS * 4, M_CHANNELS * 8, NUM_BLOCKS[3], 2);
let stats_dim = M_CHANNELS * 8 * (NUM_MEL_BINS / 8);
let seg_1_weight = zeros(&[EMBED_DIM, stats_dim * 2]);
let seg_1_bias = zeros(&[EMBED_DIM]);
Self {
config,
stem_conv: Conv2dWeights::empty(M_CHANNELS, 1, 3, 1, 1),
stem_bn: BatchNormWeights::empty(M_CHANNELS),
stage1,
stage2,
stage3,
stage4,
seg_1_weight,
seg_1_bias,
}
}
pub fn from_hub(model_id: &str, config: WeSpeakerConfig) -> Result<Self> {
Self::from_hub_with_revision(model_id, "main", config)
}
pub fn from_hub_with_revision(model_id: &str, revision: &str, config: WeSpeakerConfig) -> Result<Self> {
let api = hf_hub::api::sync::Api::new().context(HubSnafu)?;
let repo =
api.repo(hf_hub::Repo::with_revision(model_id.to_string(), hf_hub::RepoType::Model, revision.to_string()));
let weights_path = repo.get("pytorch_model.bin").context(HubSnafu)?;
Self::from_pytorch_bin(&weights_path, config)
}
pub fn from_pytorch_bin(path: &Path, config: WeSpeakerConfig) -> Result<Self> {
let sd = pickle::load_pyannote_pytorch_bin(path, "resnet.").context(PickleSnafu)?;
let sd: StateDict = sd.into_iter().map(|(k, v)| (rename_shortcut_to_downsample(&k), v)).collect();
Self::from_state_dict(&sd, config)
}
pub fn from_state_dict(sd: &StateDict, config: WeSpeakerConfig) -> Result<Self> {
let sd = remap::fold_batchnorm(sd.clone())?;
let mut model = Self::with_zero_weights(config);
model.stem_conv.load_state_dict(&sd, "conv1").context(StateSnafu)?;
model.stem_bn.load_state_dict(&sd, "bn1").context(StateSnafu)?;
model.stage1.load_state_dict(&sd, "layer1").context(StateSnafu)?;
model.stage2.load_state_dict(&sd, "layer2").context(StateSnafu)?;
model.stage3.load_state_dict(&sd, "layer3").context(StateSnafu)?;
model.stage4.load_state_dict(&sd, "layer4").context(StateSnafu)?;
model.seg_1_weight = get_tensor(&sd, "seg_1.weight").context(StateSnafu)?;
model.seg_1_bias = get_tensor(&sd, "seg_1.bias").context(StateSnafu)?;
Ok(model)
}
pub fn forward(&self, feats: &Tensor, weights: &Tensor, batch: &BoundVariable) -> Result<Tensor> {
let b = batch.as_sint();
let feats = feats.try_shrink([Some((SInt::Const(0), b.clone())), None, None]).context(TensorSnafu)?;
let weights = weights.try_shrink([Some((SInt::Const(0), b)), None]).context(TensorSnafu)?;
let x = feats.try_permute(&[0, 2, 1]).context(TensorSnafu)?;
let x = x.try_unsqueeze(1).context(TensorSnafu)?;
let x = self.stem_bn.forward(&self.stem_conv.forward(&x)?)?.relu().context(TensorSnafu)?;
let x = self.stage1.forward(&x)?;
let x = self.stage2.forward(&x)?;
let x = self.stage3.forward(&x)?;
let x = self.stage4.forward(&x)?;
let stats = tstp_forward(&x, &weights)?;
stats.linear().weight(&self.seg_1_weight).bias(&self.seg_1_bias).call().context(TensorSnafu)
}
}
impl HasStateDict for WeSpeakerResNet34 {
    /// Exports all parameters under `prefix`.
    ///
    /// Keys are built with `state::prefixed` from `conv1`, `bn1`,
    /// `layer1`..`layer4`, `seg_1.weight` and `seg_1.bias`.
    fn state_dict(&self, prefix: &str) -> StateDict {
        let stages = [
            (&self.stage1, "layer1"),
            (&self.stage2, "layer2"),
            (&self.stage3, "layer3"),
            (&self.stage4, "layer4"),
        ];

        let mut exported = self.stem_conv.state_dict(&state::prefixed(prefix, "conv1"));
        exported.extend(self.stem_bn.state_dict(&state::prefixed(prefix, "bn1")));
        for (stage, name) in stages {
            exported.extend(stage.state_dict(&state::prefixed(prefix, name)));
        }
        exported.insert(state::prefixed(prefix, "seg_1.weight"), self.seg_1_weight.clone());
        exported.insert(state::prefixed(prefix, "seg_1.bias"), self.seg_1_bias.clone());
        exported
    }

    /// Loads all parameters from `sd` under `prefix`, using the same key layout
    /// as `state_dict`.
    ///
    /// Unlike `WeSpeakerResNet34::from_state_dict`, this path does not call
    /// `remap::fold_batchnorm`; `sd` is consumed exactly as given.
    ///
    /// # Errors
    ///
    /// Returns the first `state::Error` encountered; components are loaded in
    /// the order stem conv, stem bn, stage1..stage4, seg_1 weight, seg_1 bias.
    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> std::result::Result<(), state::Error> {
        self.stem_conv.load_state_dict(sd, &state::prefixed(prefix, "conv1"))?;
        self.stem_bn.load_state_dict(sd, &state::prefixed(prefix, "bn1"))?;

        let stages = [
            (&mut self.stage1, "layer1"),
            (&mut self.stage2, "layer2"),
            (&mut self.stage3, "layer3"),
            (&mut self.stage4, "layer4"),
        ];
        for (stage, name) in stages {
            stage.load_state_dict(sd, &state::prefixed(prefix, name))?;
        }

        self.seg_1_weight = get_tensor(sd, &state::prefixed(prefix, "seg_1.weight"))?;
        self.seg_1_bias = get_tensor(sd, &state::prefixed(prefix, "seg_1.bias"))?;
        Ok(())
    }
}
/// Renames every `.shortcut.` path segment in a checkpoint key to `.downsample.`.
///
/// Only dot-delimited occurrences are rewritten, so a key that *starts* with
/// `shortcut.` is returned unchanged.
fn rename_shortcut_to_downsample(key: &str) -> String {
    const CHECKPOINT_SEGMENT: &str = ".shortcut.";
    const MODEL_SEGMENT: &str = ".downsample.";
    key.replace(CHECKPOINT_SEGMENT, MODEL_SEGMENT)
}