1use candle_core::{Module, Tensor};
11use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
12use ferrum_types::{FerrumError, Result};
13use tracing::info;
14
/// Mel filterbank for the speaker-encoder front end, embedded at compile time.
///
/// Stored as raw little-endian `f32` values; `parse_mel_filters` asserts the
/// blob holds exactly 128 x 513 of them (128 * 513 * 4 bytes) and decodes them
/// row-major, one row of 513 FFT-bin weights per mel band.
const MEL_FILTERS: &[u8] = include_bytes!("mel_filters_spkenc.bin");
18
19fn reflect_pad_1d(x: &Tensor, pad_left: usize, pad_right: usize) -> candle_core::Result<Tensor> {
27 if pad_left == 0 && pad_right == 0 {
28 return Ok(x.clone());
29 }
30 let t = x.dim(2)?;
31 let mut parts: Vec<Tensor> = Vec::new();
32
33 let x = x.contiguous()?;
34
35 if pad_left > 0 {
37 let mut left_indices = Vec::with_capacity(pad_left);
38 for i in (1..=pad_left).rev() {
39 left_indices.push(i.min(t - 1) as u32);
40 }
41 let idx = Tensor::new(left_indices, x.device())?;
42 parts.push(x.index_select(&idx, 2)?);
43 }
44
45 parts.push(x.clone());
47
48 if pad_right > 0 {
50 let mut right_indices = Vec::with_capacity(pad_right);
51 for i in 1..=pad_right {
52 right_indices.push((t - 1).saturating_sub(i) as u32);
53 }
54 let idx = Tensor::new(right_indices, x.device())?;
55 parts.push(x.index_select(&idx, 2)?);
56 }
57
58 Tensor::cat(&parts, 2)
59}
60
61struct ReflectConv1d {
63 conv: Conv1d,
64 pad_left: usize,
65 pad_right: usize,
66}
67
68impl ReflectConv1d {
69 fn load(
70 in_ch: usize,
71 out_ch: usize,
72 kernel_size: usize,
73 dilation: usize,
74 groups: usize,
75 vb: VarBuilder,
76 ) -> candle_core::Result<Self> {
77 let effective_kernel = dilation * (kernel_size - 1) + 1;
78 let total_pad = effective_kernel - 1;
79 let pad_left = total_pad / 2;
80 let pad_right = total_pad - pad_left;
81
82 let cfg = Conv1dConfig {
83 padding: 0,
84 stride: 1,
85 dilation,
86 groups,
87 cudnn_fwd_algo: None,
88 };
89 let w = vb.get((out_ch, in_ch / groups, kernel_size), "weight")?;
90 let b = vb.get(out_ch, "bias").ok();
91 Ok(Self {
92 conv: Conv1d::new(w, b, cfg),
93 pad_left,
94 pad_right,
95 })
96 }
97
98 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
99 let x = reflect_pad_1d(x, self.pad_left, self.pad_right)?;
100 self.conv.forward(&x)
101 }
102}
103
104struct TimeDelayNetBlock {
109 conv: ReflectConv1d,
110}
111
112impl TimeDelayNetBlock {
113 fn load(
114 in_ch: usize,
115 out_ch: usize,
116 kernel_size: usize,
117 dilation: usize,
118 vb: VarBuilder,
119 ) -> candle_core::Result<Self> {
120 let conv = ReflectConv1d::load(in_ch, out_ch, kernel_size, dilation, 1, vb.pp("conv"))?;
121 Ok(Self { conv })
122 }
123
124 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
125 self.conv.forward(x)?.relu()
126 }
127}
128
129struct Res2NetBlock {
136 scale: usize, chunk_size: usize,
138 blocks: Vec<TimeDelayNetBlock>, }
140
141impl Res2NetBlock {
142 fn load(
143 channels: usize,
144 kernel_size: usize,
145 dilation: usize,
146 scale: usize,
147 vb: VarBuilder,
148 ) -> candle_core::Result<Self> {
149 let chunk_size = channels / scale;
150 let mut blocks = Vec::with_capacity(scale - 1);
151 for j in 0..(scale - 1) {
152 let tdnn = TimeDelayNetBlock::load(
153 chunk_size,
154 chunk_size,
155 kernel_size,
156 dilation,
157 vb.pp(format!("blocks.{j}")),
158 )?;
159 blocks.push(tdnn);
160 }
161 Ok(Self {
162 scale,
163 chunk_size,
164 blocks,
165 })
166 }
167
168 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
169 let mut outputs: Vec<Tensor> = Vec::with_capacity(self.scale);
171
172 let chunk0 = x.narrow(1, 0, self.chunk_size)?;
174 outputs.push(chunk0);
175
176 for i in 1..self.scale {
177 let chunk_i = x.narrow(1, i * self.chunk_size, self.chunk_size)?;
178 let input_i = if i == 1 {
180 chunk_i
181 } else {
182 (chunk_i + outputs.last().unwrap())?
183 };
184 let out_i = self.blocks[i - 1].forward(&input_i)?;
185 outputs.push(out_i);
186 }
187
188 Tensor::cat(&outputs, 1)
189 }
190}
191
192struct SqueezeExcitationBlock {
197 conv1: ReflectConv1d,
198 conv2: ReflectConv1d,
199}
200
201impl SqueezeExcitationBlock {
202 fn load(channels: usize, se_channels: usize, vb: VarBuilder) -> candle_core::Result<Self> {
203 let conv1 = ReflectConv1d::load(channels, se_channels, 1, 1, 1, vb.pp("conv1"))?;
204 let conv2 = ReflectConv1d::load(se_channels, channels, 1, 1, 1, vb.pp("conv2"))?;
205 Ok(Self { conv1, conv2 })
206 }
207
208 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
209 let s = x.mean_keepdim(2)?;
211 let s = self.conv1.forward(&s)?.relu()?;
212 let s = self.conv2.forward(&s)?;
213 let s = sigmoid(&s)?;
215 x.broadcast_mul(&s)
217 }
218}
219
220fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
222 let ones = x.ones_like()?;
223 let neg = x.neg()?;
224 ones.broadcast_div(&(neg.exp()? + 1.0)?)
225}
226
227struct SERes2NetBlock {
232 tdnn1: TimeDelayNetBlock,
233 res2net_block: Res2NetBlock,
234 tdnn2: TimeDelayNetBlock,
235 se_block: SqueezeExcitationBlock,
236 shortcut: Option<ReflectConv1d>, }
238
239impl SERes2NetBlock {
240 fn load(
241 in_ch: usize,
242 out_ch: usize,
243 kernel_size: usize,
244 dilation: usize,
245 se_channels: usize,
246 res2net_scale: usize,
247 vb: VarBuilder,
248 ) -> candle_core::Result<Self> {
249 let tdnn1 = TimeDelayNetBlock::load(in_ch, out_ch, 1, 1, vb.pp("tdnn1"))?;
250 let res2net_block = Res2NetBlock::load(
251 out_ch,
252 kernel_size,
253 dilation,
254 res2net_scale,
255 vb.pp("res2net_block"),
256 )?;
257 let tdnn2 = TimeDelayNetBlock::load(out_ch, out_ch, 1, 1, vb.pp("tdnn2"))?;
258 let se_block = SqueezeExcitationBlock::load(out_ch, se_channels, vb.pp("se_block"))?;
259 let shortcut = if in_ch != out_ch {
260 Some(ReflectConv1d::load(
261 in_ch,
262 out_ch,
263 1,
264 1,
265 1,
266 vb.pp("shortcut.conv"),
267 )?)
268 } else {
269 None
270 };
271 Ok(Self {
272 tdnn1,
273 res2net_block,
274 tdnn2,
275 se_block,
276 shortcut,
277 })
278 }
279
280 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
281 let residual = match &self.shortcut {
282 Some(sc) => sc.forward(x)?,
283 None => x.clone(),
284 };
285 let out = self.tdnn1.forward(x)?;
286 let out = self.res2net_block.forward(&out)?;
287 let out = self.tdnn2.forward(&out)?;
288 let out = self.se_block.forward(&out)?;
289 out + residual
290 }
291}
292
293struct AttentiveStatisticsPooling {
299 tdnn: TimeDelayNetBlock, conv: ReflectConv1d, }
302
303impl AttentiveStatisticsPooling {
304 fn load(
305 channels: usize,
306 attention_channels: usize,
307 vb: VarBuilder,
308 ) -> candle_core::Result<Self> {
309 let tdnn = TimeDelayNetBlock::load(channels * 3, attention_channels, 1, 1, vb.pp("tdnn"))?;
310 let conv = ReflectConv1d::load(attention_channels, channels, 1, 1, 1, vb.pp("conv"))?;
311 Ok(Self { tdnn, conv })
312 }
313
314 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
315 let mean = x.mean_keepdim(2)?; let diff = x.broadcast_sub(&mean)?;
320 let var = diff.sqr()?.mean_keepdim(2)?;
321 let std = (var + 1e-5)?.sqrt()?; let mean_exp = mean.expand(x.dims())?; let std_exp = std.expand(x.dims())?; let cat = Tensor::cat(&[x, &mean_exp, &std_exp], 1)?;
329
330 let attn = self.tdnn.forward(&cat)?; let attn = attn.tanh()?;
333 let attn = self.conv.forward(&attn)?; let attn = softmax_dim2(&attn)?; let weighted = (x * &attn)?;
340 let w_mean = weighted.sum_keepdim(2)?; let w_diff = x.broadcast_sub(&w_mean)?;
343 let w_var = (w_diff.sqr()? * &attn)?.sum_keepdim(2)?;
344 let w_std = (w_var + 1e-5)?.sqrt()?; Tensor::cat(&[&w_mean, &w_std], 1)
348 }
349}
350
351fn softmax_dim2(x: &Tensor) -> candle_core::Result<Tensor> {
353 let max = x.max_keepdim(2)?;
354 let shifted = x.broadcast_sub(&max)?;
355 let exp = shifted.exp()?;
356 let sum = exp.sum_keepdim(2)?;
357 exp.broadcast_div(&sum)
358}
359
360pub struct SpeakerEncoder {
371 block0: TimeDelayNetBlock,
372 se_blocks: Vec<SERes2NetBlock>, mfa: TimeDelayNetBlock,
374 asp: AttentiveStatisticsPooling,
375 fc: ReflectConv1d,
376}
377
378impl SpeakerEncoder {
379 pub fn load_with_dim(vb: VarBuilder, enc_dim: usize) -> Result<Self> {
382 info!("Loading ECAPA-TDNN speaker encoder");
383
384 let block0 = TimeDelayNetBlock::load(128, 512, 5, 1, vb.pp("blocks.0"))
386 .map_err(|e| FerrumError::model(format!("speaker_encoder blocks.0: {e}")))?;
387
388 let mut se_blocks = Vec::with_capacity(3);
390 for (i, dilation) in [(1usize, 2usize), (2, 3), (3, 4)] {
391 let blk = SERes2NetBlock::load(
392 512, 512, 3, dilation,
396 128, 8, vb.pp(format!("blocks.{i}")),
399 )
400 .map_err(|e| FerrumError::model(format!("speaker_encoder blocks.{i}: {e}")))?;
401 se_blocks.push(blk);
402 }
403
404 let mfa = TimeDelayNetBlock::load(1536, 1536, 1, 1, vb.pp("mfa"))
406 .map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
407
408 let asp = AttentiveStatisticsPooling::load(1536, 128, vb.pp("asp"))
410 .map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
411
412 let fc = ReflectConv1d::load(3072, enc_dim, 1, 1, 1, vb.pp("fc"))
414 .map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
415
416 info!(
417 "Speaker encoder loaded (ECAPA-TDNN, {}-dim output)",
418 enc_dim
419 );
420 Ok(Self {
421 block0,
422 se_blocks,
423 mfa,
424 asp,
425 fc,
426 })
427 }
428
429 pub fn forward(&self, mel: &Tensor) -> Result<Tensor> {
434 let x = mel
436 .transpose(1, 2)
437 .and_then(|t| t.contiguous())
438 .map_err(|e| FerrumError::model(format!("speaker_encoder transpose: {e}")))?;
439
440 let x = self
442 .block0
443 .forward(&x)
444 .map_err(|e| FerrumError::model(format!("speaker_encoder block0: {e}")))?;
445
446 let mut se_outputs = Vec::with_capacity(3);
448 let mut x = x;
449 for (i, blk) in self.se_blocks.iter().enumerate() {
450 x = blk
451 .forward(&x)
452 .map_err(|e| FerrumError::model(format!("speaker_encoder se_block[{i}]: {e}")))?;
453 se_outputs.push(x.clone());
454 }
455
456 let mfa_in = Tensor::cat(&se_outputs, 1)
458 .map_err(|e| FerrumError::model(format!("speaker_encoder mfa cat: {e}")))?;
459 let mfa_out = self
460 .mfa
461 .forward(&mfa_in)
462 .map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
463
464 let asp_out = self
466 .asp
467 .forward(&mfa_out)
468 .map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
469
470 let fc_out = self
472 .fc
473 .forward(&asp_out)
474 .map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
475
476 let emb = fc_out
478 .squeeze(2)
479 .map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(2): {e}")))?
480 .squeeze(0)
481 .map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(0): {e}")))?;
482
483 Ok(emb)
484 }
485}
486
487pub fn mel_spectrogram_speaker_encoder(pcm: &[f32]) -> Vec<f32> {
502 use rustfft::{num_complex::Complex, FftPlanner};
503
504 const N_FFT: usize = 1024;
505 const HOP_SIZE: usize = 256;
506 const WIN_SIZE: usize = 1024;
507 const N_MELS: usize = 128;
508 const N_FFT_HALF: usize = N_FFT / 2 + 1; let mel_filters = parse_mel_filters();
512
513 let pad_size = (N_FFT - HOP_SIZE) / 2; let padded = reflect_pad_pcm(pcm, pad_size);
516
517 let n_frames = (padded.len() - N_FFT) / HOP_SIZE + 1;
519
520 let hann: Vec<f32> = (0..WIN_SIZE)
522 .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / WIN_SIZE as f32).cos()))
523 .collect();
524
525 let mut planner = FftPlanner::<f32>::new();
526 let fft = planner.plan_fft_forward(N_FFT);
527
528 let mut magnitudes = vec![0f32; N_FFT_HALF * n_frames];
531 let mut buffer = vec![Complex::new(0f32, 0f32); N_FFT];
532
533 for t in 0..n_frames {
534 let offset = t * HOP_SIZE;
535 for i in 0..N_FFT {
536 buffer[i] = Complex::new(padded[offset + i] * hann[i], 0.0);
537 }
538 fft.process(&mut buffer);
539 for f in 0..N_FFT_HALF {
540 let mag_sq = buffer[f].re * buffer[f].re + buffer[f].im * buffer[f].im;
541 magnitudes[f * n_frames + t] = (mag_sq + 1e-9).sqrt();
542 }
543 }
544
545 let mut mel_spec = vec![0f32; N_MELS * n_frames];
548 for m in 0..N_MELS {
549 for t in 0..n_frames {
550 let mut sum = 0f32;
551 for f in 0..N_FFT_HALF {
552 sum += mel_filters[m * N_FFT_HALF + f] * magnitudes[f * n_frames + t];
553 }
554 mel_spec[m * n_frames + t] = sum;
555 }
556 }
557
558 for v in &mut mel_spec {
560 *v = v.max(1e-5).ln();
561 }
562
563 let mut output = vec![0f32; n_frames * N_MELS];
565 for t in 0..n_frames {
566 for m in 0..N_MELS {
567 output[t * N_MELS + m] = mel_spec[m * n_frames + t];
568 }
569 }
570
571 output
572}
573
574fn parse_mel_filters() -> Vec<f32> {
577 const N_MELS: usize = 128;
578 const N_FFT_HALF: usize = 513;
579 let expected = N_MELS * N_FFT_HALF;
580
581 assert_eq!(
582 MEL_FILTERS.len(),
583 expected * 4,
584 "mel_filters_spkenc.bin: expected {} bytes ({} x {} x 4), got {}",
585 expected * 4,
586 N_MELS,
587 N_FFT_HALF,
588 MEL_FILTERS.len()
589 );
590
591 let mut filters = vec![0f32; expected];
592 for (i, chunk) in MEL_FILTERS.chunks_exact(4).enumerate() {
593 filters[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
594 }
595 filters
596}
597
/// Reflect-pads `signal` by `pad` samples on both ends.
///
/// The signal is mirrored about its edge samples without repeating them
/// (`[1, 2, 3]` padded by 1 becomes `[2, 1, 2, 3, 2]`). When `pad` is at least
/// the signal length, the mirroring continues periodically (numpy
/// `pad(mode="reflect")` semantics) rather than clamping inconsistently. A
/// single-sample signal is repeated.
///
/// # Panics
/// Panics if `signal` is empty and `pad > 0`, since there is nothing to reflect.
fn reflect_pad_pcm(signal: &[f32], pad: usize) -> Vec<f32> {
    let n = signal.len();
    if pad == 0 {
        return signal.to_vec();
    }
    assert!(n > 0, "reflect_pad_pcm: cannot reflect-pad an empty signal");

    // Mirroring about both edges is periodic with period 2 * (n - 1); a
    // single-sample signal has period 0 and maps every position to index 0.
    let period = 2 * (n - 1);
    let reflect = |pos: isize| -> f32 {
        if period == 0 {
            return signal[0];
        }
        let folded = pos.rem_euclid(period as isize) as usize;
        signal[if folded < n { folded } else { period - folded }]
    };

    let (n_i, pad_i) = (n as isize, pad as isize);
    let mut out = Vec::with_capacity(n + 2 * pad);
    out.extend((-pad_i..0).map(&reflect));
    out.extend_from_slice(signal);
    out.extend((n_i..n_i + pad_i).map(&reflect));
    out
}
613
#[cfg(test)]
mod tests {
    use super::*;

    /// Reflection about the edges without repeating the edge samples.
    #[test]
    fn test_reflect_pad_pcm() {
        let padded = reflect_pad_pcm(&[1.0, 2.0, 3.0, 4.0, 5.0], 2);
        assert_eq!(padded, vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]);
    }

    /// The embedded filterbank decodes to 128 x 513 weights, not all zero.
    #[test]
    fn test_mel_filters_parse() {
        let filters = parse_mel_filters();
        assert_eq!(filters.len(), 128 * 513);
        assert!(
            filters.iter().any(|&v| v != 0.0),
            "mel filterbank should have non-zero entries"
        );
    }

    /// Output is a whole number of 128-band frames, and there is at least one.
    #[test]
    fn test_mel_spectrogram_shape() {
        let silence = vec![0.0f32; 24000];
        let mel = mel_spectrogram_speaker_encoder(&silence);
        assert_eq!(mel.len() % 128, 0, "mel length should be multiple of 128");
        assert!(mel.len() / 128 > 0, "should have at least 1 frame");
    }

    /// Spot-checks the logistic function at 0 and +/-1.
    #[test]
    fn test_sigmoid() {
        let dev = candle_core::Device::Cpu;
        let x = Tensor::new(&[0.0f32, 1.0, -1.0], &dev).unwrap();
        let got = sigmoid(&x).unwrap().to_vec1::<f32>().unwrap();
        let expected: [(f32, f32); 3] = [(0.5, 1e-5), (0.7311, 1e-3), (0.2689, 1e-3)];
        assert_eq!(got.len(), expected.len());
        for (value, (target, tol)) in got.iter().zip(expected.iter()) {
            assert!((value - target).abs() < *tol);
        }
    }
}