Skip to main content

ferrum_models/architectures/
speaker_encoder.rs

1//! ECAPA-TDNN Speaker Encoder for Qwen3-TTS voice cloning.
2//!
//! Takes a mel spectrogram [1, T, 128] and outputs an `enc_dim`-dimensional speaker embedding
//! (1024 for the 0.6B model, 2048 for the 1.7B model).
4//! Used for zero-shot voice cloning by conditioning the TTS Talker on reference audio.
5//!
//! Architecture: TDNN → 3x SE-Res2Net blocks → MFA → ASP → FC → [enc_dim]
7//!
8//! Weight prefix in safetensors: `speaker_encoder.`
9
10use candle_core::{Module, Tensor};
11use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
12use ferrum_types::{FerrumError, Result};
13use tracing::info;
14
15// ── Mel filterbank for speaker encoder (128 x 513, row-major f32le) ──────
16
17const MEL_FILTERS: &[u8] = include_bytes!("mel_filters_spkenc.bin");
18
19// ── Reflect-pad Conv1d (Metal-compatible) ─────────────────────────────────
20//
21// candle's Conv1d only supports zero-padding. For "same" convolution with
22// reflect padding mode, we manually reflect-pad the input tensor and run
23// Conv1d with padding=0.
24
25/// Reflect-pad a 3D tensor [B, C, T] along the time (last) dimension.
26fn reflect_pad_1d(x: &Tensor, pad_left: usize, pad_right: usize) -> candle_core::Result<Tensor> {
27    if pad_left == 0 && pad_right == 0 {
28        return Ok(x.clone());
29    }
30    let t = x.dim(2)?;
31    let mut parts: Vec<Tensor> = Vec::new();
32
33    let x = x.contiguous()?;
34
35    // Left reflection: indices pad_left, pad_left-1, ..., 1
36    if pad_left > 0 {
37        let mut left_indices = Vec::with_capacity(pad_left);
38        for i in (1..=pad_left).rev() {
39            left_indices.push(i.min(t - 1) as u32);
40        }
41        let idx = Tensor::new(left_indices, x.device())?;
42        parts.push(x.index_select(&idx, 2)?);
43    }
44
45    // Original
46    parts.push(x.clone());
47
48    // Right reflection: indices t-2, t-3, ..., t-1-pad_right
49    if pad_right > 0 {
50        let mut right_indices = Vec::with_capacity(pad_right);
51        for i in 1..=pad_right {
52            right_indices.push((t - 1).saturating_sub(i) as u32);
53        }
54        let idx = Tensor::new(right_indices, x.device())?;
55        parts.push(x.index_select(&idx, 2)?);
56    }
57
58    Tensor::cat(&parts, 2)
59}
60
61/// Conv1d with reflect padding ("same" mode with padding_mode="reflect").
62struct ReflectConv1d {
63    conv: Conv1d,
64    pad_left: usize,
65    pad_right: usize,
66}
67
68impl ReflectConv1d {
69    fn load(
70        in_ch: usize,
71        out_ch: usize,
72        kernel_size: usize,
73        dilation: usize,
74        groups: usize,
75        vb: VarBuilder,
76    ) -> candle_core::Result<Self> {
77        let effective_kernel = dilation * (kernel_size - 1) + 1;
78        let total_pad = effective_kernel - 1;
79        let pad_left = total_pad / 2;
80        let pad_right = total_pad - pad_left;
81
82        let cfg = Conv1dConfig {
83            padding: 0,
84            stride: 1,
85            dilation,
86            groups,
87            cudnn_fwd_algo: None,
88        };
89        let w = vb.get((out_ch, in_ch / groups, kernel_size), "weight")?;
90        let b = vb.get(out_ch, "bias").ok();
91        Ok(Self {
92            conv: Conv1d::new(w, b, cfg),
93            pad_left,
94            pad_right,
95        })
96    }
97
98    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
99        let x = reflect_pad_1d(x, self.pad_left, self.pad_right)?;
100        self.conv.forward(&x)
101    }
102}
103
104// ── TimeDelayNetBlock (TDNN) ──────────────────────────────────────────────
105//
106// Conv1d(in→out, kernel, dilation, padding="same", padding_mode="reflect") + ReLU
107
108struct TimeDelayNetBlock {
109    conv: ReflectConv1d,
110}
111
112impl TimeDelayNetBlock {
113    fn load(
114        in_ch: usize,
115        out_ch: usize,
116        kernel_size: usize,
117        dilation: usize,
118        vb: VarBuilder,
119    ) -> candle_core::Result<Self> {
120        let conv = ReflectConv1d::load(in_ch, out_ch, kernel_size, dilation, 1, vb.pp("conv"))?;
121        Ok(Self { conv })
122    }
123
124    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
125        self.conv.forward(x)?.relu()
126    }
127}
128
129// ── Res2NetBlock ──────────────────────────────────────────────────────────
130//
131// Splits input into 8 chunks along channel dim. First chunk passes through
132// unchanged. Chunks 1..7 go through TDNN blocks with cumulative addition
133// from the previous chunk's output.
134
135struct Res2NetBlock {
136    scale: usize, // 8
137    chunk_size: usize,
138    blocks: Vec<TimeDelayNetBlock>, // 7 blocks (indices 0..6 → chunks 1..7)
139}
140
141impl Res2NetBlock {
142    fn load(
143        channels: usize,
144        kernel_size: usize,
145        dilation: usize,
146        scale: usize,
147        vb: VarBuilder,
148    ) -> candle_core::Result<Self> {
149        let chunk_size = channels / scale;
150        let mut blocks = Vec::with_capacity(scale - 1);
151        for j in 0..(scale - 1) {
152            let tdnn = TimeDelayNetBlock::load(
153                chunk_size,
154                chunk_size,
155                kernel_size,
156                dilation,
157                vb.pp(format!("blocks.{j}")),
158            )?;
159            blocks.push(tdnn);
160        }
161        Ok(Self {
162            scale,
163            chunk_size,
164            blocks,
165        })
166    }
167
168    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
169        // x: [B, C, T] — split along channel dim into `scale` chunks
170        let mut outputs: Vec<Tensor> = Vec::with_capacity(self.scale);
171
172        // chunk[0]: pass through (no conv)
173        let chunk0 = x.narrow(1, 0, self.chunk_size)?;
174        outputs.push(chunk0);
175
176        for i in 1..self.scale {
177            let chunk_i = x.narrow(1, i * self.chunk_size, self.chunk_size)?;
178            // First block (i=1) processes chunk directly; subsequent blocks add previous output
179            let input_i = if i == 1 {
180                chunk_i
181            } else {
182                (chunk_i + outputs.last().unwrap())?
183            };
184            let out_i = self.blocks[i - 1].forward(&input_i)?;
185            outputs.push(out_i);
186        }
187
188        Tensor::cat(&outputs, 1)
189    }
190}
191
192// ── SqueezeExcitationBlock ────────────────────────────────────────────────
193//
194// Global avg pool → Conv1d(ch→se_ch, k=1) + ReLU → Conv1d(se_ch→ch, k=1) + Sigmoid → multiply
195
196struct SqueezeExcitationBlock {
197    conv1: ReflectConv1d,
198    conv2: ReflectConv1d,
199}
200
201impl SqueezeExcitationBlock {
202    fn load(channels: usize, se_channels: usize, vb: VarBuilder) -> candle_core::Result<Self> {
203        let conv1 = ReflectConv1d::load(channels, se_channels, 1, 1, 1, vb.pp("conv1"))?;
204        let conv2 = ReflectConv1d::load(se_channels, channels, 1, 1, 1, vb.pp("conv2"))?;
205        Ok(Self { conv1, conv2 })
206    }
207
208    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
209        // Global average pooling over time: [B, C, T] → [B, C, 1]
210        let s = x.mean_keepdim(2)?;
211        let s = self.conv1.forward(&s)?.relu()?;
212        let s = self.conv2.forward(&s)?;
213        // Sigmoid
214        let s = sigmoid(&s)?;
215        // Channel-wise scale
216        x.broadcast_mul(&s)
217    }
218}
219
220/// Manual sigmoid: 1 / (1 + exp(-x))
221fn sigmoid(x: &Tensor) -> candle_core::Result<Tensor> {
222    let ones = x.ones_like()?;
223    let neg = x.neg()?;
224    ones.broadcast_div(&(neg.exp()? + 1.0)?)
225}
226
227// ── SqueezeExcitationRes2NetBlock ─────────────────────────────────────────
228//
229// TDNN1(in→out, k=1) + Res2Net(out, k=3, dilation) + TDNN2(out→out, k=1) + SE + residual
230
231struct SERes2NetBlock {
232    tdnn1: TimeDelayNetBlock,
233    res2net_block: Res2NetBlock,
234    tdnn2: TimeDelayNetBlock,
235    se_block: SqueezeExcitationBlock,
236    shortcut: Option<ReflectConv1d>, // only if in_ch != out_ch
237}
238
239impl SERes2NetBlock {
240    fn load(
241        in_ch: usize,
242        out_ch: usize,
243        kernel_size: usize,
244        dilation: usize,
245        se_channels: usize,
246        res2net_scale: usize,
247        vb: VarBuilder,
248    ) -> candle_core::Result<Self> {
249        let tdnn1 = TimeDelayNetBlock::load(in_ch, out_ch, 1, 1, vb.pp("tdnn1"))?;
250        let res2net_block = Res2NetBlock::load(
251            out_ch,
252            kernel_size,
253            dilation,
254            res2net_scale,
255            vb.pp("res2net_block"),
256        )?;
257        let tdnn2 = TimeDelayNetBlock::load(out_ch, out_ch, 1, 1, vb.pp("tdnn2"))?;
258        let se_block = SqueezeExcitationBlock::load(out_ch, se_channels, vb.pp("se_block"))?;
259        let shortcut = if in_ch != out_ch {
260            Some(ReflectConv1d::load(
261                in_ch,
262                out_ch,
263                1,
264                1,
265                1,
266                vb.pp("shortcut.conv"),
267            )?)
268        } else {
269            None
270        };
271        Ok(Self {
272            tdnn1,
273            res2net_block,
274            tdnn2,
275            se_block,
276            shortcut,
277        })
278    }
279
280    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
281        let residual = match &self.shortcut {
282            Some(sc) => sc.forward(x)?,
283            None => x.clone(),
284        };
285        let out = self.tdnn1.forward(x)?;
286        let out = self.res2net_block.forward(&out)?;
287        let out = self.tdnn2.forward(&out)?;
288        let out = self.se_block.forward(&out)?;
289        out + residual
290    }
291}
292
293// ── Attentive Statistics Pooling ──────────────────────────────────────────
294//
295// Compute attention-weighted mean and std over time dimension.
296// Input: [B, C, T]  Output: [B, C*2, 1]
297
298struct AttentiveStatisticsPooling {
299    tdnn: TimeDelayNetBlock, // Conv1d(ch*3 → attention_ch, k=1)
300    conv: ReflectConv1d,     // Conv1d(attention_ch → ch, k=1)
301}
302
303impl AttentiveStatisticsPooling {
304    fn load(
305        channels: usize,
306        attention_channels: usize,
307        vb: VarBuilder,
308    ) -> candle_core::Result<Self> {
309        let tdnn = TimeDelayNetBlock::load(channels * 3, attention_channels, 1, 1, vb.pp("tdnn"))?;
310        let conv = ReflectConv1d::load(attention_channels, channels, 1, 1, 1, vb.pp("conv"))?;
311        Ok(Self { tdnn, conv })
312    }
313
314    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
315        // x: [B, C, T]
316
317        // Compute mean and std over time (full mask — no padding)
318        let mean = x.mean_keepdim(2)?; // [B, C, 1]
319        let diff = x.broadcast_sub(&mean)?;
320        let var = diff.sqr()?.mean_keepdim(2)?;
321        let std = (var + 1e-5)?.sqrt()?; // [B, C, 1]
322
323        // Expand mean/std to match time dim
324        let mean_exp = mean.expand(x.dims())?; // [B, C, T]
325        let std_exp = std.expand(x.dims())?; // [B, C, T]
326
327        // Concat [x, mean, std] along channel dim → [B, C*3, T]
328        let cat = Tensor::cat(&[x, &mean_exp, &std_exp], 1)?;
329
330        // Attention weights
331        let attn = self.tdnn.forward(&cat)?; // [B, attn_ch, T] (includes ReLU)
332        let attn = attn.tanh()?;
333        let attn = self.conv.forward(&attn)?; // [B, C, T]
334
335        // Softmax over time dimension
336        let attn = softmax_dim2(&attn)?; // [B, C, T]
337
338        // Weighted statistics
339        let weighted = (x * &attn)?;
340        let w_mean = weighted.sum_keepdim(2)?; // [B, C, 1]
341
342        let w_diff = x.broadcast_sub(&w_mean)?;
343        let w_var = (w_diff.sqr()? * &attn)?.sum_keepdim(2)?;
344        let w_std = (w_var + 1e-5)?.sqrt()?; // [B, C, 1]
345
346        // Concat mean + std → [B, C*2, 1]
347        Tensor::cat(&[&w_mean, &w_std], 1)
348    }
349}
350
351/// Softmax over dim 2 (time dimension for [B, C, T] tensors).
352fn softmax_dim2(x: &Tensor) -> candle_core::Result<Tensor> {
353    let max = x.max_keepdim(2)?;
354    let shifted = x.broadcast_sub(&max)?;
355    let exp = shifted.exp()?;
356    let sum = exp.sum_keepdim(2)?;
357    exp.broadcast_div(&sum)
358}
359
360// ── SpeakerEncoder (full model) ───────────────────────────────────────────
361//
362// blocks[0]: TDNN(128→512, k=5, dilation=1)
363// blocks[1]: SE-Res2Net(512→512, k=3, dilation=2)
364// blocks[2]: SE-Res2Net(512→512, k=3, dilation=3)
365// blocks[3]: SE-Res2Net(512→512, k=3, dilation=4)
366// MFA: concat(blocks[1..3] outputs) → TDNN(1536→1536, k=1)
367// ASP: AttentiveStatisticsPooling(1536, attention_channels=128)
// FC: Conv1d(3072→enc_dim, k=1) with reflect padding (enc_dim = 1024 for 0.6B, 2048 for 1.7B)
369
370pub struct SpeakerEncoder {
371    block0: TimeDelayNetBlock,
372    se_blocks: Vec<SERes2NetBlock>, // 3 blocks
373    mfa: TimeDelayNetBlock,
374    asp: AttentiveStatisticsPooling,
375    fc: ReflectConv1d,
376}
377
378impl SpeakerEncoder {
379    /// Load speaker encoder weights from VarBuilder.
380    /// Expects the VarBuilder to be scoped to the `speaker_encoder` prefix.
381    pub fn load_with_dim(vb: VarBuilder, enc_dim: usize) -> Result<Self> {
382        info!("Loading ECAPA-TDNN speaker encoder");
383
384        // blocks.0: TDNN(128→512, k=5, dilation=1)
385        let block0 = TimeDelayNetBlock::load(128, 512, 5, 1, vb.pp("blocks.0"))
386            .map_err(|e| FerrumError::model(format!("speaker_encoder blocks.0: {e}")))?;
387
388        // blocks.1-3: SE-Res2Net(512→512, k=3, dilation=2,3,4)
389        let mut se_blocks = Vec::with_capacity(3);
390        for (i, dilation) in [(1usize, 2usize), (2, 3), (3, 4)] {
391            let blk = SERes2NetBlock::load(
392                512, // in_ch
393                512, // out_ch
394                3,   // kernel_size
395                dilation,
396                128, // se_channels
397                8,   // res2net_scale
398                vb.pp(format!("blocks.{i}")),
399            )
400            .map_err(|e| FerrumError::model(format!("speaker_encoder blocks.{i}: {e}")))?;
401            se_blocks.push(blk);
402        }
403
404        // MFA: TDNN(1536→1536, k=1, dilation=1)
405        let mfa = TimeDelayNetBlock::load(1536, 1536, 1, 1, vb.pp("mfa"))
406            .map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
407
408        // ASP: AttentiveStatisticsPooling(1536, attention_channels=128)
409        let asp = AttentiveStatisticsPooling::load(1536, 128, vb.pp("asp"))
410            .map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
411
412        // FC: Conv1d(3072→enc_dim, k=1) — enc_dim is 1024 for 0.6B, 2048 for 1.7B
413        let fc = ReflectConv1d::load(3072, enc_dim, 1, 1, 1, vb.pp("fc"))
414            .map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
415
416        info!(
417            "Speaker encoder loaded (ECAPA-TDNN, {}-dim output)",
418            enc_dim
419        );
420        Ok(Self {
421            block0,
422            se_blocks,
423            mfa,
424            asp,
425            fc,
426        })
427    }
428
429    /// Forward pass: mel spectrogram → speaker embedding.
430    ///
431    /// - `mel`: [1, T, 128] mel spectrogram tensor
432    /// - Returns: [1024] speaker embedding vector
433    pub fn forward(&self, mel: &Tensor) -> Result<Tensor> {
434        // Transpose [1, T, 128] → [1, 128, T] for Conv1d processing
435        let x = mel
436            .transpose(1, 2)
437            .and_then(|t| t.contiguous())
438            .map_err(|e| FerrumError::model(format!("speaker_encoder transpose: {e}")))?;
439
440        // blocks[0]: initial TDNN
441        let x = self
442            .block0
443            .forward(&x)
444            .map_err(|e| FerrumError::model(format!("speaker_encoder block0: {e}")))?;
445
446        // blocks[1..3]: SE-Res2Net blocks (collect outputs for MFA)
447        let mut se_outputs = Vec::with_capacity(3);
448        let mut x = x;
449        for (i, blk) in self.se_blocks.iter().enumerate() {
450            x = blk
451                .forward(&x)
452                .map_err(|e| FerrumError::model(format!("speaker_encoder se_block[{i}]: {e}")))?;
453            se_outputs.push(x.clone());
454        }
455
456        // MFA: concat SE block outputs along channel dim → TDNN
457        let mfa_in = Tensor::cat(&se_outputs, 1)
458            .map_err(|e| FerrumError::model(format!("speaker_encoder mfa cat: {e}")))?;
459        let mfa_out = self
460            .mfa
461            .forward(&mfa_in)
462            .map_err(|e| FerrumError::model(format!("speaker_encoder mfa: {e}")))?;
463
464        // ASP: [B, 1536, T] → [B, 3072, 1]
465        let asp_out = self
466            .asp
467            .forward(&mfa_out)
468            .map_err(|e| FerrumError::model(format!("speaker_encoder asp: {e}")))?;
469
470        // FC: [B, 3072, 1] → [B, 1024, 1]
471        let fc_out = self
472            .fc
473            .forward(&asp_out)
474            .map_err(|e| FerrumError::model(format!("speaker_encoder fc: {e}")))?;
475
476        // Squeeze: [B, 1024, 1] → [1024]
477        let emb = fc_out
478            .squeeze(2)
479            .map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(2): {e}")))?
480            .squeeze(0)
481            .map_err(|e| FerrumError::model(format!("speaker_encoder squeeze(0): {e}")))?;
482
483        Ok(emb)
484    }
485}
486
487// ── Mel spectrogram for speaker encoder ───────────────────────────────────
488//
489// Different from Whisper mel:
490//   - 24kHz sample rate, n_fft=1024, hop=256, n_mels=128, fmin=0, fmax=12000
491//   - Magnitude (NOT squared): sqrt(re^2 + im^2 + 1e-9)
492//   - Log compression: log(clamp(x, 1e-5)) (NOT log10)
493//   - No normalization (no max-8.0 clamp, no +4/4 scaling)
494//   - Returns [1, T, 128] in row-major (time x mels)
495
496/// Compute mel spectrogram for the speaker encoder.
497///
498/// - `pcm`: audio samples (f32, 24kHz mono)
499/// - Returns flat `Vec<f32>` in [T, 128] layout (row-major, time-first)
500///   suitable for creating a [1, T, 128] tensor.
501pub fn mel_spectrogram_speaker_encoder(pcm: &[f32]) -> Vec<f32> {
502    use rustfft::{num_complex::Complex, FftPlanner};
503
504    const N_FFT: usize = 1024;
505    const HOP_SIZE: usize = 256;
506    const WIN_SIZE: usize = 1024;
507    const N_MELS: usize = 128;
508    const N_FFT_HALF: usize = N_FFT / 2 + 1; // 513
509
510    // Parse mel filterbank from embedded binary (128 x 513, f32le row-major)
511    let mel_filters = parse_mel_filters();
512
513    // Step 1: Reflect-pad for center=False with (n_fft - hop_size) / 2 on each side
514    let pad_size = (N_FFT - HOP_SIZE) / 2; // 384
515    let padded = reflect_pad_pcm(pcm, pad_size);
516
517    // Step 2: STFT with Hann window
518    let n_frames = (padded.len() - N_FFT) / HOP_SIZE + 1;
519
520    // Hann window (periodic: 2*pi*i / win_size, NOT win_size-1)
521    let hann: Vec<f32> = (0..WIN_SIZE)
522        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / WIN_SIZE as f32).cos()))
523        .collect();
524
525    let mut planner = FftPlanner::<f32>::new();
526    let fft = planner.plan_fft_forward(N_FFT);
527
528    // Compute magnitude spectrogram: sqrt(re^2 + im^2 + 1e-9) (NOT squared)
529    // Layout: [N_FFT_HALF, n_frames] column-major (freq x time)
530    let mut magnitudes = vec![0f32; N_FFT_HALF * n_frames];
531    let mut buffer = vec![Complex::new(0f32, 0f32); N_FFT];
532
533    for t in 0..n_frames {
534        let offset = t * HOP_SIZE;
535        for i in 0..N_FFT {
536            buffer[i] = Complex::new(padded[offset + i] * hann[i], 0.0);
537        }
538        fft.process(&mut buffer);
539        for f in 0..N_FFT_HALF {
540            let mag_sq = buffer[f].re * buffer[f].re + buffer[f].im * buffer[f].im;
541            magnitudes[f * n_frames + t] = (mag_sq + 1e-9).sqrt();
542        }
543    }
544
545    // Step 3: Mel projection: mel_filters[N_MELS, N_FFT_HALF] @ mag[N_FFT_HALF, n_frames]
546    // Result: [N_MELS, n_frames]
547    let mut mel_spec = vec![0f32; N_MELS * n_frames];
548    for m in 0..N_MELS {
549        for t in 0..n_frames {
550            let mut sum = 0f32;
551            for f in 0..N_FFT_HALF {
552                sum += mel_filters[m * N_FFT_HALF + f] * magnitudes[f * n_frames + t];
553            }
554            mel_spec[m * n_frames + t] = sum;
555        }
556    }
557
558    // Step 4: Log compression: log(clamp(x, min=1e-5))
559    for v in &mut mel_spec {
560        *v = v.max(1e-5).ln();
561    }
562
563    // Step 5: Transpose from [N_MELS, n_frames] to [n_frames, N_MELS] (time x mels)
564    let mut output = vec![0f32; n_frames * N_MELS];
565    for t in 0..n_frames {
566        for m in 0..N_MELS {
567            output[t * N_MELS + m] = mel_spec[m * n_frames + t];
568        }
569    }
570
571    output
572}
573
574/// Parse mel filterbank from embedded binary data.
575/// Shape: [128, 513], stored as f32 little-endian row-major.
576fn parse_mel_filters() -> Vec<f32> {
577    const N_MELS: usize = 128;
578    const N_FFT_HALF: usize = 513;
579    let expected = N_MELS * N_FFT_HALF;
580
581    assert_eq!(
582        MEL_FILTERS.len(),
583        expected * 4,
584        "mel_filters_spkenc.bin: expected {} bytes ({} x {} x 4), got {}",
585        expected * 4,
586        N_MELS,
587        N_FFT_HALF,
588        MEL_FILTERS.len()
589    );
590
591    let mut filters = vec![0f32; expected];
592    for (i, chunk) in MEL_FILTERS.chunks_exact(4).enumerate() {
593        filters[i] = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
594    }
595    filters
596}
597
/// Reflect-pad a PCM signal by `pad` samples on each side without repeating the edge
/// samples.
///
/// For `signal = [1, 2, 3, 4, 5]` and `pad = 2` the result is `[3, 2, 1, 2, 3, 4, 5, 4, 3]`.
///
/// If `pad >= signal.len()`, exact reflection is impossible. Mirror indices are then
/// clamped to the signal bounds instead of panicking. An empty `signal` has nothing to
/// reflect and yields an empty `Vec`.
fn reflect_pad_pcm(signal: &[f32], pad: usize) -> Vec<f32> {
    if signal.is_empty() {
        return Vec::new();
    }
    let last = signal.len() - 1;
    let mut out = Vec::with_capacity(signal.len() + 2 * pad);
    // Left reflection: signal[pad], signal[pad-1], ..., signal[1]
    out.extend((1..=pad).rev().map(|i| signal[i.min(last)]));
    out.extend_from_slice(signal);
    // Right reflection: signal[n-2], signal[n-3], ..., signal[n-1-pad]
    out.extend((1..=pad).map(|i| signal[last.saturating_sub(i)]));
    out
}
613
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reflect_pad_pcm() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let padded = reflect_pad_pcm(&signal, 2);
        // Left: signal[2], signal[1] = 3.0, 2.0
        // Right: signal[3], signal[2] = 4.0, 3.0
        assert_eq!(padded, vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]);
    }

    #[test]
    fn test_reflect_pad_pcm_pad_exceeds_signal() {
        // pad >= len: mirror indices are clamped to the signal bounds instead of panicking.
        let padded = reflect_pad_pcm(&[1.0, 2.0, 3.0], 3);
        assert_eq!(padded, vec![3.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 1.0]);
    }

    #[test]
    fn test_mel_filters_parse() {
        let filters = parse_mel_filters();
        assert_eq!(filters.len(), 128 * 513);
        // Should contain some non-zero values (it's a real filterbank)
        let nonzero = filters.iter().filter(|&&v| v != 0.0).count();
        assert!(nonzero > 0, "mel filterbank should have non-zero entries");
    }

    #[test]
    fn test_mel_spectrogram_shape() {
        // 1 second of silence at 24kHz
        let pcm = vec![0.0f32; 24000];
        let mel = mel_spectrogram_speaker_encoder(&pcm);
        assert_eq!(mel.len() % 128, 0, "mel length should be multiple of 128");
        // n_frames = (24000 + 2*384 - 1024) / 256 + 1
        //          = (24768 - 1024) / 256 + 1 = 23744 / 256 + 1 = 92 + 1 = 93
        assert_eq!(mel.len() / 128, 93);
        // The 1e-9 magnitude floor and 1e-5 clamp keep silence finite (no -inf / NaN).
        assert!(mel.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_sigmoid() {
        let dev = candle_core::Device::Cpu;
        let x = Tensor::new(&[0.0f32, 1.0, -1.0], &dev).unwrap();
        let s = sigmoid(&x).unwrap().to_vec1::<f32>().unwrap();
        assert!((s[0] - 0.5).abs() < 1e-5);
        assert!((s[1] - 0.7311).abs() < 1e-3);
        assert!((s[2] - 0.2689).abs() < 1e-3);
    }
}