use candle_core::{Module, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
use ferrum_types::{FerrumError, Result};
use tracing::info;
/// Raw bytes of `mel_filters_spkenc.bin`, embedded into the binary at compile time.
///
/// `parse_mel_filters` decodes them as 128 x 513 little-endian `f32` values.
const MEL_FILTERS: &[u8] = include_bytes!("mel_filters_spkenc.bin");
/// Pads axis 2 of `x` by mirroring samples around each edge (edge sample not repeated).
///
/// `pad_left` samples are prepended and `pad_right` samples appended along axis 2.
/// When both pads are zero the input is returned as a cheap clone without
/// inspecting its shape.
///
/// If a pad is larger than `t - 1` (with `t` the length of axis 2) the mirror
/// indices are clamped rather than wrapped: on the left they saturate at the last
/// sample, on the right they saturate at the first sample.
///
/// # Errors
///
/// Returns an error when padding is requested but axis 2 is missing or empty
/// (there is nothing to mirror), or when any underlying tensor op fails.
fn reflect_pad_1d(x: &Tensor, pad_left: usize, pad_right: usize) -> candle_core::Result<Tensor> {
    if pad_left == 0 && pad_right == 0 {
        return Ok(x.clone());
    }

    let t = x.dim(2)?;
    if t == 0 {
        // `t - 1` below would underflow; report the problem instead of panicking.
        candle_core::bail!("reflect_pad_1d: cannot reflect-pad a tensor whose axis 2 is empty");
    }
    let last = t - 1;

    let x = x.contiguous()?;
    let mut parts: Vec<Tensor> = Vec::with_capacity(3);

    if pad_left > 0 {
        // Mirror of the leading samples: indices pad_left, ..., 2, 1 (clamped to `last`).
        let indices: Vec<u32> = (1..=pad_left).rev().map(|i| i.min(last) as u32).collect();
        let idx = Tensor::new(indices, x.device())?;
        parts.push(x.index_select(&idx, 2)?);
    }

    parts.push(x.clone());

    if pad_right > 0 {
        // Mirror of the trailing samples: indices t-2, t-3, ... (clamped to 0).
        let indices: Vec<u32> = (1..=pad_right)
            .map(|i| last.saturating_sub(i) as u32)
            .collect();
        let idx = Tensor::new(indices, x.device())?;
        parts.push(x.index_select(&idx, 2)?);
    }

    Tensor::cat(&parts, 2)
}
/// A stride-1 `Conv1d` that reflect-pads its input first, so the length of
/// axis 2 is the same before and after the convolution.
struct ReflectConv1d {
    /// Underlying convolution, configured with no built-in padding.
    conv: Conv1d,
    /// Number of mirrored samples prepended along axis 2.
    pad_left: usize,
    /// Number of mirrored samples appended along axis 2.
    pad_right: usize,
}

impl ReflectConv1d {
    /// Loads the convolution parameters from `vb`.
    ///
    /// Reads `weight` with shape `(out_ch, in_ch / groups, kernel_size)` and an
    /// optional `bias` of length `out_ch`; a missing or unreadable bias is
    /// treated as "no bias". The total padding is `dilation * (kernel_size - 1)`,
    /// split so that the right side receives the extra sample when it is odd.
    fn load(
        in_ch: usize,
        out_ch: usize,
        kernel_size: usize,
        dilation: usize,
        groups: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        // A dilated kernel spans `dilation * (kernel_size - 1) + 1` samples, so
        // this many samples must be added to keep the output length unchanged.
        let padding_total = dilation * (kernel_size - 1);
        let left = padding_total / 2;
        let right = padding_total - left;

        let weight = vb.get((out_ch, in_ch / groups, kernel_size), "weight")?;
        let bias = vb.get(out_ch, "bias").ok();

        let config = Conv1dConfig {
            stride: 1,
            padding: 0,
            dilation,
            groups,
            cudnn_fwd_algo: None,
        };

        Ok(Self {
            conv: Conv1d::new(weight, bias, config),
            pad_left: left,
            pad_right: right,
        })
    }

    /// Reflect-pads `x` along axis 2 and applies the convolution.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let padded = reflect_pad_1d(x, self.pad_left, self.pad_right)?;
        self.conv.forward(&padded)
    }
}
/// TDNN layer: a reflect-padded 1-D convolution followed by ReLU.
struct TimeDelayNetBlock {
    /// Length-preserving convolution, stored under the `conv` prefix.
    conv: ReflectConv1d,
}

impl TimeDelayNetBlock {
    /// Loads the block's single ungrouped convolution from `vb.pp("conv")`.
    fn load(
        in_ch: usize,
        out_ch: usize,
        kernel_size: usize,
        dilation: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        ReflectConv1d::load(in_ch, out_ch, kernel_size, dilation, 1, vb.pp("conv"))
            .map(|conv| Self { conv })
    }

    /// Applies the convolution and then ReLU.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let pre_activation = self.conv.forward(x)?;
        pre_activation.relu()
    }
}
/// Res2Net-style block: the channel axis is split into `scale` equal chunks
/// that are processed hierarchically by small TDNN layers.
struct Res2NetBlock {
    /// Number of channel chunks.
    scale: usize,
    /// Channels per chunk (`channels / scale`).
    chunk_size: usize,
    /// One TDNN per chunk except the first, which is passed through unchanged.
    blocks: Vec<TimeDelayNetBlock>,
}

impl Res2NetBlock {
    /// Loads `scale - 1` TDNN layers from the prefixes `blocks.0`, `blocks.1`, ...
    ///
    /// # Errors
    ///
    /// Fails when `scale` is zero or does not evenly divide `channels`
    /// (the chunks would not cover every input channel, so `forward` would
    /// silently drop the remainder), or when a weight cannot be loaded.
    fn load(
        channels: usize,
        kernel_size: usize,
        dilation: usize,
        scale: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        if scale == 0 {
            // Avoids a divide-by-zero and a `scale - 1` underflow below.
            candle_core::bail!("Res2NetBlock: scale must be at least 1");
        }
        if channels % scale != 0 {
            candle_core::bail!(
                "Res2NetBlock: channels ({}) must be divisible by scale ({})",
                channels,
                scale
            );
        }

        let chunk_size = channels / scale;
        let blocks = (0..scale - 1)
            .map(|j| {
                TimeDelayNetBlock::load(
                    chunk_size,
                    chunk_size,
                    kernel_size,
                    dilation,
                    vb.pp(format!("blocks.{j}")),
                )
            })
            .collect::<candle_core::Result<Vec<_>>>()?;

        Ok(Self {
            scale,
            chunk_size,
            blocks,
        })
    }

    /// Runs the hierarchical pass over the chunks of axis 1.
    ///
    /// Chunk 0 is copied through. Chunk 1 goes straight into its TDNN. Every
    /// later chunk is first summed with the previous chunk's output and then
    /// fed to its own TDNN. All outputs are concatenated back along axis 1.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let mut outputs: Vec<Tensor> = Vec::with_capacity(self.scale);
        outputs.push(x.narrow(1, 0, self.chunk_size)?);

        for (j, block) in self.blocks.iter().enumerate() {
            // `blocks[j]` handles chunk `j + 1`.
            let chunk = x.narrow(1, (j + 1) * self.chunk_size, self.chunk_size)?;
            let input = match outputs.last() {
                // From the third chunk on, add the previous TDNN output.
                Some(prev) if j > 0 => (chunk + prev)?,
                _ => chunk,
            };
            outputs.push(block.forward(&input)?);
        }

        Tensor::cat(&outputs, 1)
    }
}
/// Squeeze-and-excitation gate: rescales each channel by a factor derived
/// from the mean of the input over axis 2.
struct SqueezeExcitationBlock {
    /// Kernel-size-1 convolution mapping `channels` to `se_channels`.
    conv1: ReflectConv1d,
    /// Kernel-size-1 convolution mapping `se_channels` back to `channels`.
    conv2: ReflectConv1d,
}

impl SqueezeExcitationBlock {
    /// Loads both kernel-size-1 convolutions from the `conv1` / `conv2` prefixes.
    fn load(channels: usize, se_channels: usize, vb: VarBuilder) -> candle_core::Result<Self> {
        Ok(Self {
            conv1: ReflectConv1d::load(channels, se_channels, 1, 1, 1, vb.pp("conv1"))?,
            conv2: ReflectConv1d::load(se_channels, channels, 1, 1, 1, vb.pp("conv2"))?,
        })
    }

    /// Computes the per-channel gate and multiplies it into `x`.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        // Squeeze: average over axis 2, keeping it as a length-1 axis.
        let pooled = x.mean_keepdim(2)?;
        // Excite: bottleneck -> ReLU -> expand -> sigmoid gate in (0, 1).
        let hidden = self.conv1.forward(&pooled)?.relu()?;
        let gate = sigmoid(&self.conv2.forward(&hidden)?)?;
        x.broadcast_mul(&gate)
    }
}
/// Element-wise logistic function, computed as `1 / (1 + exp(-x))`.
fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
    let numerator = x.ones_like()?;
    let denominator = (x.neg()?.exp()? + 1.0)?;
    numerator.broadcast_div(&denominator)
}
/// One SE-Res2Net stage: TDNN -> Res2Net -> TDNN -> squeeze-excitation, with a
/// residual connection around the whole stack.
struct SERes2NetBlock {
    /// Kernel-size-1 TDNN mapping `in_ch` to `out_ch`.
    tdnn1: TimeDelayNetBlock,
    /// Multi-scale dilated convolution stack.
    res2net_block: Res2NetBlock,
    /// Kernel-size-1 TDNN applied after the Res2Net stack.
    tdnn2: TimeDelayNetBlock,
    /// Channel re-weighting gate.
    se_block: SqueezeExcitationBlock,
    /// Kernel-size-1 projection for the residual path; present only when the
    /// input and output channel counts differ.
    shortcut: Option<ReflectConv1d>,
}

impl SERes2NetBlock {
    /// Loads the sub-modules from the prefixes `tdnn1`, `res2net_block`,
    /// `tdnn2`, `se_block` and, when `in_ch != out_ch`, `shortcut.conv`.
    fn load(
        in_ch: usize,
        out_ch: usize,
        kernel_size: usize,
        dilation: usize,
        se_channels: usize,
        res2net_scale: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        let tdnn1 = TimeDelayNetBlock::load(in_ch, out_ch, 1, 1, vb.pp("tdnn1"))?;
        let res2net_block = Res2NetBlock::load(
            out_ch,
            kernel_size,
            dilation,
            res2net_scale,
            vb.pp("res2net_block"),
        )?;
        let tdnn2 = TimeDelayNetBlock::load(out_ch, out_ch, 1, 1, vb.pp("tdnn2"))?;
        let se_block = SqueezeExcitationBlock::load(out_ch, se_channels, vb.pp("se_block"))?;

        // A projection is only needed when the residual cannot be added as-is.
        let shortcut = (in_ch != out_ch)
            .then(|| ReflectConv1d::load(in_ch, out_ch, 1, 1, 1, vb.pp("shortcut.conv")))
            .transpose()?;

        Ok(Self {
            tdnn1,
            res2net_block,
            tdnn2,
            se_block,
            shortcut,
        })
    }

    /// Runs the stage and adds the (optionally projected) input back in.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let residual = if let Some(projection) = &self.shortcut {
            projection.forward(x)?
        } else {
            x.clone()
        };

        let expanded = self.tdnn1.forward(x)?;
        let multi_scale = self.res2net_block.forward(&expanded)?;
        let mixed = self.tdnn2.forward(&multi_scale)?;
        let gated = self.se_block.forward(&mixed)?;

        gated + residual
    }
}
/// Attentive statistics pooling: collapses axis 2 into an attention-weighted
/// mean and standard deviation per channel.
struct AttentiveStatisticsPooling {
    /// Kernel-size-1 TDNN over the input concatenated with its global statistics.
    tdnn: TimeDelayNetBlock,
    /// Kernel-size-1 convolution producing one attention logit per channel and step.
    conv: ReflectConv1d,
}

impl AttentiveStatisticsPooling {
    /// Loads the attention TDNN (`tdnn`, `3 * channels -> attention_channels`)
    /// and the logit projection (`conv`, `attention_channels -> channels`).
    fn load(
        channels: usize,
        attention_channels: usize,
        vb: VarBuilder,
    ) -> candle_core::Result<Self> {
        Ok(Self {
            tdnn: TimeDelayNetBlock::load(channels * 3, attention_channels, 1, 1, vb.pp("tdnn"))?,
            conv: ReflectConv1d::load(attention_channels, channels, 1, 1, 1, vb.pp("conv"))?,
        })
    }

    /// Pools `x` over axis 2.
    ///
    /// The result concatenates the weighted mean and weighted standard
    /// deviation along axis 1, so it has twice the channels of `x` and a
    /// length-1 axis 2.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        // Unweighted global statistics, used only as context for the attention.
        let global_mean = x.mean_keepdim(2)?;
        let centered = x.broadcast_sub(&global_mean)?;
        let global_var = centered.sqr()?.mean_keepdim(2)?;
        let global_std = (global_var + 1e-5)?.sqrt()?;

        // Tile the statistics over axis 2 and stack them with the input on axis 1.
        let mean_tiled = global_mean.expand(x.dims())?;
        let std_tiled = global_std.expand(x.dims())?;
        let context = Tensor::cat(&[x, &mean_tiled, &std_tiled], 1)?;

        // Attention weights, normalised over axis 2 independently per channel.
        let hidden = self.tdnn.forward(&context)?.tanh()?;
        let logits = self.conv.forward(&hidden)?;
        let weights = softmax_dim2(&logits)?;

        // Weighted first and second moments.
        let pooled_mean = (x * &weights)?.sum_keepdim(2)?;
        let deviation = x.broadcast_sub(&pooled_mean)?;
        let pooled_var = (deviation.sqr()? * &weights)?.sum_keepdim(2)?;
        let pooled_std = (pooled_var + 1e-5)?.sqrt()?;

        Tensor::cat(&[&pooled_mean, &pooled_std], 1)
    }
}
/// Softmax over axis 2, with the per-row maximum subtracted before
/// exponentiating so large logits cannot overflow.
fn softmax_dim2(x: &Tensor) -> candle_core::Result<Tensor> {
    let row_max = x.max_keepdim(2)?;
    let numerators = x.broadcast_sub(&row_max)?.exp()?;
    let normaliser = numerators.sum_keepdim(2)?;
    numerators.broadcast_div(&normaliser)
}
pub struct SpeakerEncoder {
block0: TimeDelayNetBlock,
se_blocks: Vec<SERes2NetBlock>, mfa: TimeDelayNetBlock,
asp: AttentiveStatisticsPooling,
fc: ReflectConv1d,
}
impl SpeakerEncoder {
pub fn load_with_dim(vb: VarBuilder, enc_dim: usize) -> Result<Self> {
info!("Loading ECAPA-TDNN speaker encoder");
let block0 = TimeDelayNetBlock::load(128, 512, 5, 1, vb.pp("blocks.0"))
.map_err(|e| FerrumError::model(format!("speaker_encoder blocks.0: {e}")))?;
let mut se_blocks = Vec::with_capacity(3);
for (i, dilation) in [(1usize, 2usize), (2, 3), (3, 4)] {
let blk = SERes2NetBlock::load(
512, 512, 3, dilation,
128, 8, vb.pp(format!("blocks.{i}")),
)
.map_err(|e| FerrumError::model(format!("speaker_encoder blocks.{i}: {e}")))?;
se_blocks.push(blk);
}
let mfa = TimeDelayNetBlock::load(1536, 1536, 1, 1, vb.pp("mfa"))
.map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
let asp = AttentiveStatisticsPooling::load(1536, 128, vb.pp("asp"))
.map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
let fc = ReflectConv1d::load(3072, enc_dim, 1, 1, 1, vb.pp("fc"))
.map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
info!(
"Speaker encoder loaded (ECAPA-TDNN, {}-dim output)",
enc_dim
);
Ok(Self {
block0,
se_blocks,
mfa,
asp,
fc,
})
}
pub fn forward(&self, mel: &Tensor) -> Result<Tensor> {
let x = mel
.transpose(1, 2)
.and_then(|t| t.contiguous())
.map_err(|e| FerrumError::model(format!("speaker_encoder transpose: {e}")))?;
let x = self
.block0
.forward(&x)
.map_err(|e| FerrumError::model(format!("speaker_encoder block0: {e}")))?;
let mut se_outputs = Vec::with_capacity(3);
let mut x = x;
for (i, blk) in self.se_blocks.iter().enumerate() {
x = blk
.forward(&x)
.map_err(|e| FerrumError::model(format!("speaker_encoder se_block[{i}]: {e}")))?;
se_outputs.push(x.clone());
}
let mfa_in = Tensor::cat(&se_outputs, 1)
.map_err(|e| FerrumError::model(format!("speaker_encoder mfa cat: {e}")))?;
let mfa_out = self
.mfa
.forward(&mfa_in)
.map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
let asp_out = self
.asp
.forward(&mfa_out)
.map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
let fc_out = self
.fc
.forward(&asp_out)
.map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
let emb = fc_out
.squeeze(2)
.map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(2): {e}")))?
.squeeze(0)
.map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(0): {e}")))?;
Ok(emb)
}
}
/// Computes the log-mel spectrogram fed to the speaker encoder.
///
/// The signal is reflect-padded by `(N_FFT - HOP_SIZE) / 2 = 384` samples on
/// each side, cut into 1024-sample frames every 256 samples, multiplied by a
/// periodic Hann window and transformed with a 1024-point FFT. The magnitude
/// spectrum (`sqrt(re^2 + im^2 + 1e-9)`) of the 513 non-negative frequency
/// bins is projected through the embedded 128-band mel filterbank, floored at
/// `1e-5` and passed through the natural logarithm.
///
/// # Returns
///
/// A frame-major buffer of `n_frames * 128` values, i.e. element
/// `t * 128 + m` is mel band `m` of frame `t`, where
/// `n_frames = pcm.len() / 256` (integer division).
///
/// Inputs shorter than 256 samples (including an empty slice) do not fill a
/// single frame and yield an empty vector instead of panicking.
pub fn mel_spectrogram_speaker_encoder(pcm: &[f32]) -> Vec<f32> {
    use rustfft::{num_complex::Complex, FftPlanner};

    const N_FFT: usize = 1024;
    const HOP_SIZE: usize = 256;
    const WIN_SIZE: usize = 1024;
    const N_MELS: usize = 128;
    const N_FFT_HALF: usize = N_FFT / 2 + 1;

    let pad_size = (N_FFT - HOP_SIZE) / 2;

    // Not enough samples for one frame: bail out before `padded.len() - N_FFT`
    // can underflow (and before padding an empty signal).
    if pcm.len() + 2 * pad_size < N_FFT {
        return Vec::new();
    }

    let mel_filters = parse_mel_filters();
    let padded = reflect_pad_pcm(pcm, pad_size);
    let n_frames = (padded.len() - N_FFT) / HOP_SIZE + 1;

    // Periodic Hann window (denominator is WIN_SIZE, not WIN_SIZE - 1).
    let hann: Vec<f32> = (0..WIN_SIZE)
        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / WIN_SIZE as f32).cos()))
        .collect();

    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(N_FFT);

    // Scratch buffers reused for every frame.
    let mut buffer = vec![Complex::new(0f32, 0f32); N_FFT];
    let mut magnitude = vec![0f32; N_FFT_HALF];

    let mut output = vec![0f32; n_frames * N_MELS];
    for (t, mel_row) in output.chunks_exact_mut(N_MELS).enumerate() {
        let offset = t * HOP_SIZE;
        let frame = &padded[offset..offset + N_FFT];

        for ((slot, &sample), &weight) in buffer.iter_mut().zip(frame).zip(&hann) {
            *slot = Complex::new(sample * weight, 0.0);
        }
        fft.process(&mut buffer);

        // Only the first N_FFT_HALF bins are needed for a real-valued input.
        for (mag, bin) in magnitude.iter_mut().zip(&buffer) {
            let mag_sq = bin.re * bin.re + bin.im * bin.im;
            *mag = (mag_sq + 1e-9).sqrt();
        }

        // Each filterbank row holds the N_FFT_HALF weights of one mel band.
        for (mel, filter_row) in mel_row.iter_mut().zip(mel_filters.chunks_exact(N_FFT_HALF)) {
            let energy = filter_row
                .iter()
                .zip(&magnitude)
                .fold(0f32, |acc, (&w, &m)| acc + w * m);
            *mel = energy.max(1e-5).ln();
        }
    }

    output
}
/// Decodes the embedded mel filterbank into `128 * 513` `f32` weights.
///
/// The bytes are read as consecutive little-endian `f32` values and returned
/// in file order.
///
/// # Panics
///
/// Panics if the embedded file is not exactly `128 * 513 * 4` bytes long.
fn parse_mel_filters() -> Vec<f32> {
    const N_MELS: usize = 128;
    const N_FFT_HALF: usize = 513;

    let expected = N_MELS * N_FFT_HALF;
    assert_eq!(
        MEL_FILTERS.len(),
        expected * 4,
        "mel_filters_spkenc.bin: expected {} bytes ({} x {} x 4), got {}",
        expected * 4,
        N_MELS,
        N_FFT_HALF,
        MEL_FILTERS.len()
    );

    // The length check above guarantees exactly `expected` 4-byte groups.
    MEL_FILTERS
        .chunks_exact(4)
        .map(|bytes| f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
        .collect()
}
/// Returns `signal` with `pad` mirrored samples added on each side.
///
/// The edge sample itself is not repeated: padding `[1, 2, 3]` by one gives
/// `[2, 1, 2, 3, 2]`. The output always has `signal.len() + 2 * pad` elements.
///
/// When `pad` exceeds `signal.len() - 1` the mirror indices are clamped rather
/// than wrapped: on the left they saturate at the last sample, on the right at
/// the first sample.
///
/// An empty `signal` has nothing to mirror, so the padding is filled with
/// zeros instead of panicking.
fn reflect_pad_pcm(signal: &[f32], pad: usize) -> Vec<f32> {
    let n = signal.len();
    if n == 0 {
        // `n - 1` below would underflow; keep the length contract with silence.
        return vec![0.0; 2 * pad];
    }
    let last = n - 1;

    let mut out = Vec::with_capacity(n + 2 * pad);
    // Leading mirror: samples pad, ..., 2, 1.
    out.extend((1..=pad).rev().map(|i| signal[i.min(last)]));
    out.extend_from_slice(signal);
    // Trailing mirror: samples n-2, n-3, ...
    out.extend((1..=pad).map(|i| signal[last.saturating_sub(i)]));
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Padding by two mirrors the two samples next to each edge.
    #[test]
    fn test_reflect_pad_pcm() {
        let samples = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let expected = [3.0f32, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0];
        assert_eq!(reflect_pad_pcm(&samples, 2), expected);
    }

    /// The embedded filterbank decodes to 128 x 513 weights, not all zero.
    #[test]
    fn test_mel_filters_parse() {
        let filters = parse_mel_filters();
        assert_eq!(filters.len(), 128 * 513);
        let has_nonzero = filters.iter().any(|&weight| weight != 0.0);
        assert!(has_nonzero, "mel filterbank should have non-zero entries");
    }

    /// 24000 samples of silence produce a whole, non-zero number of 128-band frames.
    #[test]
    fn test_mel_spectrogram_shape() {
        let silence = vec![0.0f32; 24000];
        let mel = mel_spectrogram_speaker_encoder(&silence);
        assert_eq!(mel.len() % 128, 0, "mel length should be multiple of 128");
        assert!(mel.len() / 128 > 0, "should have at least 1 frame");
    }

    /// Spot-checks the logistic function at 0, 1 and -1.
    #[test]
    fn test_sigmoid() {
        let input = Tensor::new(&[0.0f32, 1.0, -1.0], &candle_core::Device::Cpu).unwrap();
        let got = sigmoid(&input).unwrap().to_vec1::<f32>().unwrap();
        // (index, expected value, tolerance)
        for (i, want, tol) in [(0usize, 0.5f32, 1e-5f32), (1, 0.7311, 1e-3), (2, 0.2689, 1e-3)] {
            assert!((got[i] - want).abs() < tol);
        }
    }
}