use candle_core::{D, DType, IndexOp, Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Module, VarBuilder, conv1d};
use crate::config::speaker_encoder_config::SpeakerEncoderConfig;
/// 1-D convolution block: the time axis is reflect-padded, convolved, then
/// passed through ReLU (see the `Module` impl for `TimeDelayNetBlock`).
#[derive(Debug, Clone)]
pub struct TimeDelayNetBlock {
    /// Convolution constructed with `padding: 0`; padding is applied manually
    /// by reflecting the input instead.
    conv: Conv1d,
    /// Reflect padding applied to *each* side of the time axis before `conv`,
    /// computed in `new` as `(kernel_size - 1) * dilation / 2`.
    padding: usize,
}
impl TimeDelayNetBlock {
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
let padding = (kernel_size - 1) * dilation / 2;
let config = Conv1dConfig {
padding: 0,
dilation,
..Default::default()
};
let conv = conv1d(
in_channels,
out_channels,
kernel_size,
config,
vb.pp("conv"),
)?;
Ok(Self { conv, padding })
}
pub fn load(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
Self::new(in_channels, out_channels, kernel_size, dilation, vb)
}
}
fn reflect_pad_1d(xs: &Tensor, pad_left: usize, pad_right: usize) -> Result<Tensor> {
if pad_left == 0 && pad_right == 0 {
return Ok(xs.clone());
}
let xs = xs.contiguous()?;
let (_batch, _channels, length) = xs.dims3()?;
let device = xs.device();
let mut indices = Vec::with_capacity(pad_left + length + pad_right);
for i in (1..=pad_left).rev() {
indices.push(i as u32);
}
for i in 0..length {
indices.push(i as u32);
}
for i in 0..pad_right {
indices.push((length - 2 - i) as u32);
}
let indices_tensor = Tensor::from_vec(indices, pad_left + length + pad_right, device)?;
xs.index_select(&indices_tensor, 2)
}
impl Module for TimeDelayNetBlock {
    /// Reflect-pads the time axis by `self.padding` on both sides, applies the
    /// convolution, and finishes with a ReLU.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let pad = self.padding;
        let convolved =
            reflect_pad_1d(xs, pad, pad).and_then(|padded| self.conv.forward(&padded))?;
        convolved.relu()
    }
}
/// Channel-group block: the input channels are split into `scale` groups that
/// are processed hierarchically (see the `Module` impl for `Res2NetBlock`).
#[derive(Debug, Clone)]
pub struct Res2NetBlock {
    /// `scale - 1` TDNN sub-blocks; group 0 bypasses them unchanged.
    blocks: Vec<TimeDelayNetBlock>,
    /// Number of channel groups the input is split into.
    scale: usize,
}
impl Res2NetBlock {
pub fn new(
in_channels: usize,
out_channels: usize,
scale: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
let in_channel = in_channels / scale;
let hidden_channel = out_channels / scale;
let blocks = (0..(scale - 1))
.map(|i| {
TimeDelayNetBlock::new(
in_channel,
hidden_channel,
kernel_size,
dilation,
vb.pp(format!("blocks.{}", i)),
)
})
.collect::<Result<Vec<_>>>()?;
Ok(Self { blocks, scale })
}
pub fn load(
in_channels: usize,
out_channels: usize,
scale: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
Self::new(in_channels, out_channels, scale, kernel_size, dilation, vb)
}
}
impl Module for Res2NetBlock {
    /// Splits dim 1 into `scale` groups of width `channels / scale` and
    /// processes them in order:
    /// - group 0 passes through unchanged;
    /// - group 1 goes through `blocks[0]`;
    /// - every later group `g` is first summed with group `g - 1`'s output,
    ///   then goes through `blocks[g - 1]`.
    ///
    /// The group outputs are concatenated back along dim 1.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let width = xs.dim(1)? / self.scale;
        let mut pieces: Vec<Tensor> = Vec::with_capacity(self.scale);
        for group in 0..self.scale {
            let chunk = xs.narrow(1, group * width, width)?;
            let piece = if group == 0 {
                chunk
            } else {
                let input = if group == 1 {
                    chunk
                } else {
                    (chunk + &pieces[group - 1])?
                };
                self.blocks[group - 1].forward(&input)?
            };
            pieces.push(piece);
        }
        Tensor::cat(pieces.as_slice(), 1)
    }
}
/// Channel gate: the time-averaged input goes through
/// `conv1 -> ReLU -> conv2 -> sigmoid`, and the result multiplies the input
/// (see the `Module` impl for `SqueezeExcitationBlock`).
#[derive(Debug, Clone)]
pub struct SqueezeExcitationBlock {
    /// Kernel-1 convolution, `in_channels -> se_channels`.
    conv1: Conv1d,
    /// Kernel-1 convolution, `se_channels -> out_channels`.
    conv2: Conv1d,
}
impl SqueezeExcitationBlock {
pub fn new(
in_channels: usize,
se_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Self> {
let config = Conv1dConfig {
padding: 0,
..Default::default()
};
let conv1 = conv1d(in_channels, se_channels, 1, config, vb.pp("conv1"))?;
let conv2 = conv1d(se_channels, out_channels, 1, config, vb.pp("conv2"))?;
Ok(Self { conv1, conv2 })
}
pub fn load(
in_channels: usize,
se_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Self> {
Self::new(in_channels, se_channels, out_channels, vb)
}
}
impl Module for SqueezeExcitationBlock {
    /// Averages `xs` over dim 2 (keeping the dim) and computes a sigmoid gate
    /// via `conv1 -> ReLU -> conv2`. Returns `xs` scaled by that gate, broadcast
    /// over dim 2.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let squeezed = xs.mean_keepdim(2)?;
        let excited = self.conv1.forward(&squeezed)?.relu()?;
        let gate = self
            .conv2
            .forward(&excited)
            .and_then(|logits| candle_nn::ops::sigmoid(&logits))?;
        xs.broadcast_mul(&gate)
    }
}
/// Sequence `tdnn1 -> res2net_block -> tdnn2 -> se_block`, with an optional
/// residual connection (see the `Module` impl for this type).
#[derive(Debug, Clone)]
pub struct SqueezeExcitationRes2NetBlock {
    /// Kernel-1 TDNN, `in_channels -> out_channels`.
    tdnn1: TimeDelayNetBlock,
    /// Channel-group block operating on `out_channels`.
    res2net_block: Res2NetBlock,
    /// Kernel-1 TDNN, `out_channels -> out_channels`.
    tdnn2: TimeDelayNetBlock,
    /// Channel gate applied last.
    se_block: SqueezeExcitationBlock,
    /// Output channel count. The residual is only added when the input's
    /// dim 1 equals this value.
    out_channels: usize,
}
impl SqueezeExcitationRes2NetBlock {
pub fn new(
in_channels: usize,
out_channels: usize,
res2net_scale: usize,
se_channels: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
let tdnn1 = TimeDelayNetBlock::new(in_channels, out_channels, 1, 1, vb.pp("tdnn1"))?;
let res2net_block = Res2NetBlock::new(
out_channels,
out_channels,
res2net_scale,
kernel_size,
dilation,
vb.pp("res2net_block"),
)?;
let tdnn2 = TimeDelayNetBlock::new(out_channels, out_channels, 1, 1, vb.pp("tdnn2"))?;
let se_block = SqueezeExcitationBlock::new(
out_channels,
se_channels,
out_channels,
vb.pp("se_block"),
)?;
Ok(Self {
tdnn1,
res2net_block,
tdnn2,
se_block,
out_channels,
})
}
pub fn load(
in_channels: usize,
out_channels: usize,
res2net_scale: usize,
se_channels: usize,
kernel_size: usize,
dilation: usize,
vb: VarBuilder,
) -> Result<Self> {
Self::new(
in_channels,
out_channels,
res2net_scale,
se_channels,
kernel_size,
dilation,
vb,
)
}
}
impl Module for SqueezeExcitationRes2NetBlock {
    /// Runs `tdnn1`, `res2net_block`, `tdnn2` and `se_block` in sequence.
    ///
    /// The input is added back as a residual only when its dim 1 equals
    /// `out_channels`; otherwise the transformed tensor is returned alone.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let stages: [&dyn Module; 4] = [
            &self.tdnn1,
            &self.res2net_block,
            &self.tdnn2,
            &self.se_block,
        ];
        let transformed = stages
            .iter()
            .try_fold(xs.clone(), |acc, stage| stage.forward(&acc))?;
        if xs.dim(1)? != self.out_channels {
            return Ok(transformed);
        }
        xs + transformed
    }
}
/// Attentive statistics pooling over the time axis (dim 2) of a
/// `(batch, channels, time)` input.
///
/// Produces the attention-weighted mean and std concatenated along dim 1,
/// with a trailing unit dim.
#[derive(Debug, Clone)]
pub struct AttentiveStatisticsPooling {
    /// Maps the concatenation `[xs, mean, std]` (`3 * channels`) to
    /// `attention_channels`.
    tdnn: TimeDelayNetBlock,
    /// Kernel-1 convolution mapping `attention_channels` back to `channels`
    /// attention logits.
    conv: Conv1d,
    /// Added to the weighted variance before taking the square root
    /// (set to `1e-12` in `new`).
    eps: f64,
}
/// Builds a `(batch, max_len)` mask, converted to `dtype`. Entry `[b, t]` is 1
/// when `t < lengths[b]` and 0 otherwise.
///
/// `lengths` must be 1-D (one entry per batch item). Its values are compared
/// directly against positions `0..max_len`, so they are absolute frame counts.
///
/// # Errors
/// Returns an error if `lengths` is not 1-D, or if a tensor op fails.
fn length_to_mask(lengths: &Tensor, max_len: usize, dtype: DType) -> Result<Tensor> {
    // Reject other ranks up front; e.g. a `(batch, 1)` input would otherwise
    // broadcast into a mask of the wrong rank.
    let _batch_size = lengths.dims1()?;
    let device = lengths.device();
    // (1, max_len)
    let positions = Tensor::arange(0u32, max_len as u32, device)?
        .to_dtype(DType::F32)?
        .unsqueeze(0)?;
    // (batch, 1)
    let limits = lengths.to_dtype(DType::F32)?.unsqueeze(1)?;
    // candle's plain `lt` requires identical shapes, so the previous
    // `(batch, max_len).lt((batch, 1))` failed for any `max_len > 1`.
    // `broadcast_lt` expands both operands to `(batch, max_len)`.
    positions.broadcast_lt(&limits)?.to_dtype(dtype)
}
impl AttentiveStatisticsPooling {
pub fn new(channels: usize, attention_channels: usize, vb: VarBuilder) -> Result<Self> {
let tdnn = TimeDelayNetBlock::new(channels * 3, attention_channels, 1, 1, vb.pp("tdnn"))?;
let config = Conv1dConfig {
padding: 0,
..Default::default()
};
let conv = conv1d(attention_channels, channels, 1, config, vb.pp("conv"))?;
Ok(Self {
tdnn,
conv,
eps: 1e-12,
})
}
pub fn load(channels: usize, attention_channels: usize, vb: VarBuilder) -> Result<Self> {
Self::new(channels, attention_channels, vb)
}
fn compute_statistics(&self, xs: &Tensor, weights: &Tensor) -> Result<(Tensor, Tensor)> {
let mean = (xs.broadcast_mul(weights)?).sum(2)?;
let mean_expanded = mean.unsqueeze(2)?; let diff = xs.broadcast_sub(&mean_expanded)?;
let variance = diff.sqr()?.broadcast_mul(weights)?.sum(2)?;
let std = (variance + self.eps)?.sqrt()?;
Ok((mean, std))
}
pub fn forward_with_lengths(&self, xs: &Tensor, lengths: Option<&Tensor>) -> Result<Tensor> {
let seq_length = xs.dim(2)?;
let batch_size = xs.dim(0)?;
let dtype = xs.dtype();
let device = xs.device();
let mask = if let Some(lens) = lengths {
length_to_mask(lens, seq_length, dtype)?
} else {
Tensor::ones((batch_size, seq_length), dtype, device)?
};
let mask = mask.unsqueeze(1)?;
let total = mask.sum_keepdim(2)?;
let normalized_mask = mask.broadcast_div(&total)?;
let (mean, std) = self.compute_statistics(xs, &normalized_mask)?;
let mean_expanded = mean.unsqueeze(2)?.repeat((1, 1, seq_length))?;
let std_expanded = std.unsqueeze(2)?.repeat((1, 1, seq_length))?;
let attention_input = Tensor::cat(&[xs, &mean_expanded, &std_expanded], 1)?;
let attention = self.tdnn.forward(&attention_input)?;
let attention = attention.tanh()?;
let attention = self.conv.forward(&attention)?;
let attention_channels = attention.dim(1)?;
let mask_u32 = mask.to_dtype(candle_core::DType::U32)?;
let mask_expanded = mask_u32.expand((batch_size, attention_channels, seq_length))?;
let neg_inf = Tensor::full(f32::NEG_INFINITY, attention.shape(), device)?;
let attention_f32 = attention.to_dtype(candle_core::DType::F32)?;
let attention = mask_expanded.where_cond(&attention_f32, &neg_inf)?.to_dtype(dtype)?;
let attention = candle_nn::ops::softmax_last_dim(&attention)?;
let (mean_final, std_final) = self.compute_statistics(xs, &attention)?;
let pooled_stats = Tensor::cat(&[&mean_final, &std_final], 1)?;
pooled_stats.unsqueeze(2)
}
}
impl Module for AttentiveStatisticsPooling {
    /// Pools with every time step treated as valid (no per-item lengths).
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward_with_lengths(self, xs, None)
    }
}
/// Speaker encoder pipeline: an initial TDNN, a stack of SE-Res2Net blocks, a
/// TDNN (`mfa`) over the channel-concatenated block outputs, attentive
/// statistics pooling, and a final kernel-1 convolution (`fc`).
#[derive(Debug, Clone)]
pub struct SpeakerEncoder {
    /// First block, `mel_dim -> enc_channels[0]`, loaded from `blocks.0`.
    initial_tdnn: TimeDelayNetBlock,
    /// Blocks loaded from `blocks.{i}` for `i` in `1..enc_channels.len() - 1`.
    se_res2net_blocks: Vec<SqueezeExcitationRes2NetBlock>,
    /// TDNN applied to the dim-1 concatenation of every SE-Res2Net output.
    mfa: TimeDelayNetBlock,
    /// Pooling over the time axis.
    asp: AttentiveStatisticsPooling,
    /// Kernel-1 convolution, `2 * enc_channels[last] -> enc_dim`.
    fc: Conv1d,
}
impl SpeakerEncoder {
pub fn new(config: &SpeakerEncoderConfig, vb: VarBuilder) -> Result<Self> {
let initial_tdnn = TimeDelayNetBlock::new(
config.mel_dim,
config.enc_channels[0],
config.enc_kernel_sizes[0],
config.enc_dilations[0],
vb.pp("blocks.0"),
)?;
let mut se_res2net_blocks = Vec::new();
for i in 1..(config.enc_channels.len() - 1) {
let block = SqueezeExcitationRes2NetBlock::new(
config.enc_channels[i - 1],
config.enc_channels[i],
config.enc_res2net_scale,
config.enc_se_channels,
config.enc_kernel_sizes[i],
config.enc_dilations[i],
vb.pp(format!("blocks.{}", i)),
)?;
se_res2net_blocks.push(block);
}
let mfa_in_channels: usize = config.enc_channels[..(config.enc_channels.len() - 1)]
.iter()
.skip(1)
.sum();
let mfa = TimeDelayNetBlock::new(
mfa_in_channels,
config.enc_channels[config.enc_channels.len() - 1],
config.enc_kernel_sizes[config.enc_kernel_sizes.len() - 1],
config.enc_dilations[config.enc_dilations.len() - 1],
vb.pp("mfa"),
)?;
let asp = AttentiveStatisticsPooling::new(
config.enc_channels[config.enc_channels.len() - 1],
config.enc_attention_channels,
vb.pp("asp"),
)?;
let fc_config = Conv1dConfig {
padding: 0,
..Default::default()
};
let fc = conv1d(
config.enc_channels[config.enc_channels.len() - 1] * 2, config.enc_dim,
1,
fc_config,
vb.pp("fc"),
)?;
Ok(Self {
initial_tdnn,
se_res2net_blocks,
mfa,
asp,
fc,
})
}
pub fn load(config: &SpeakerEncoderConfig, vb: VarBuilder) -> Result<Self> {
Self::new(config, vb)
}
}
impl SpeakerEncoder {
pub fn forward_with_lengths(&self, xs: &Tensor, lengths: Option<&Tensor>) -> Result<Tensor> {
let hidden = xs.transpose(1, 2)?.contiguous()?;
tracing::debug!(shape = ?hidden.shape(), "speaker_encoder input after transpose");
let mut hidden = self.initial_tdnn.forward(&hidden)?;
let mut hidden_states = vec![hidden.clone()];
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = hidden.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean), Ok(first5)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
h.i((0, ..5, 0)).and_then(|t| t.to_vec1::<f32>()),
)
{
tracing::debug!(
shape = ?hidden.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
first5 = ?first5,
"After blocks[0]"
);
}
for (i, block) in self.se_res2net_blocks.iter().enumerate() {
hidden = block.forward(&hidden)?;
hidden_states.push(hidden.clone());
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = hidden.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean), Ok(first5)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
h.i((0, ..5, 0)).and_then(|t| t.to_vec1::<f32>()),
)
{
tracing::debug!(
block_idx = i + 1,
shape = ?hidden.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
first5 = ?first5,
"After SE-Res2Net block"
);
}
}
let mfa_input = Tensor::cat(&hidden_states[1..].iter().collect::<Vec<_>>(), 1)?;
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = mfa_input.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
)
{
tracing::debug!(
shape = ?mfa_input.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
"After cat(hidden_states[1:])"
);
}
let hidden = self.mfa.forward(&mfa_input)?;
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = hidden.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean), Ok(first5)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
h.i((0, ..5, 0)).and_then(|t| t.to_vec1::<f32>()),
)
{
tracing::debug!(
shape = ?hidden.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
first5 = ?first5,
"After MFA"
);
}
let hidden = self.asp.forward_with_lengths(&hidden, lengths)?;
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = hidden.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean), Ok(first10)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
h.i((0, ..10, 0)).and_then(|t| t.to_vec1::<f32>()),
)
{
tracing::debug!(
shape = ?hidden.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
first10 = ?first10,
"After ASP"
);
}
let hidden = self.fc.forward(&hidden)?;
if tracing::enabled!(tracing::Level::DEBUG)
&& let Ok(h) = hidden.to_dtype(DType::F32)
&& let (Ok(min), Ok(max), Ok(mean)) = (
h.min(D::Minus1)
.and_then(|t| t.min(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.max(D::Minus1)
.and_then(|t| t.max(D::Minus1))
.and_then(|t| Ok(t.to_vec1::<f32>()?[0])),
h.mean_all().and_then(|t| t.to_scalar::<f32>()),
)
{
tracing::debug!(
shape = ?hidden.shape(),
min = format!("{:.4}", min),
max = format!("{:.4}", max),
mean_val = format!("{:.6}", mean),
"After FC"
);
}
hidden.squeeze(2)
}
}
impl Module for SpeakerEncoder {
    /// Runs the encoder without per-item lengths, so pooling treats every
    /// frame as valid.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let all_frames_valid: Option<&Tensor> = None;
        Self::forward_with_lengths(self, xs, all_frames_valid)
    }
}