Skip to main content

any_tts/
mel.rs

1//! Mel spectrogram extraction for voice cloning and audio analysis.
2//!
3//! Provides a pure-candle implementation of Short-Time Fourier Transform (STFT)
4//! and mel filterbank projection. No external FFT library required — the DFT is
5//! computed via matrix multiplication with precomputed basis vectors.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use any_tts::mel::{MelConfig, MelSpectrogram};
11//! use candle_core::Device;
12//!
13//! let mel = MelSpectrogram::new(MelConfig::kokoro(), &Device::Cpu)?;
14//! let audio = candle_core::Tensor::zeros(24000, candle_core::DType::F32, &Device::Cpu)?;
15//! let spectrogram = mel.compute(&audio)?;
16//! // spectrogram shape: [1, 80, num_frames]
17//! ```
18
19use candle_core::{DType, Device, Tensor};
20
21use crate::error::{TtsError, TtsResult};
22
23// ---------------------------------------------------------------------------
24// Configuration
25// ---------------------------------------------------------------------------
26
/// Configuration for mel spectrogram extraction.
///
/// All fields are public so callers can build custom presets; use
/// [`MelConfig::kokoro`] for the ready-made Kokoro preset. Two configs compare
/// equal when every field is equal.
#[derive(Debug, Clone, PartialEq)]
pub struct MelConfig {
    /// FFT size (number of frequency bins before taking positive half).
    pub n_fft: usize,
    /// Hop length between STFT frames, in samples.
    pub hop_length: usize,
    /// Analysis window length, in samples. May be ≤ `n_fft`.
    pub win_length: usize,
    /// Number of mel frequency bands.
    pub n_mels: usize,
    /// Expected sample rate of input audio (Hz).
    pub sample_rate: u32,
    /// Log-mel normalization mean (subtracted before dividing by `log_std`).
    pub log_mean: f64,
    /// Log-mel normalization std (divides after subtracting `log_mean`).
    pub log_std: f64,
}

impl MelConfig {
    /// Config matching Kokoro's style encoder preprocessing.
    ///
    /// ```text
    /// n_fft=2048, hop=300, win=1200, 80 mels, 24 kHz
    /// norm: (log(1e-5 + mel) - (-4)) / 4
    /// ```
    pub fn kokoro() -> Self {
        Self {
            n_fft: 2048,
            hop_length: 300,
            win_length: 1200,
            n_mels: 80,
            sample_rate: 24000,
            log_mean: -4.0,
            log_std: 4.0,
        }
    }

    /// Number of positive frequency bins: `n_fft / 2 + 1`.
    ///
    /// This counts the DC bin and the Nyquist bin, e.g. `1025` for
    /// `n_fft = 2048`.
    pub fn n_freq(&self) -> usize {
        self.n_fft / 2 + 1
    }
}
70
71// ---------------------------------------------------------------------------
72// MelSpectrogram
73// ---------------------------------------------------------------------------
74
75/// Mel spectrogram extractor.
76///
77/// Pre-computes DFT basis vectors, Hann window, and mel filterbank on
78/// construction so that repeated calls to [`compute`](Self::compute) are fast.
79pub struct MelSpectrogram {
80    config: MelConfig,
81    /// DFT cosine basis `[n_freq, n_fft]`.
82    dft_cos: Tensor,
83    /// DFT sine basis `[n_freq, n_fft]`.
84    dft_sin: Tensor,
85    /// Hann window, zero-padded to `n_fft` length.
86    window: Tensor,
87    /// Mel filterbank `[n_mels, n_freq]`.
88    mel_basis: Tensor,
89}
90
91impl MelSpectrogram {
92    /// Create a new mel spectrogram extractor.
93    pub fn new(config: MelConfig, device: &Device) -> TtsResult<Self> {
94        let n_fft = config.n_fft;
95        let n_freq = config.n_freq();
96
97        // ── DFT basis matrices ────────────────────────────────────────
98        let mut cos_data = vec![0f32; n_freq * n_fft];
99        let mut sin_data = vec![0f32; n_freq * n_fft];
100        for k in 0..n_freq {
101            for n in 0..n_fft {
102                let angle = 2.0 * std::f32::consts::PI * (k as f32) * (n as f32) / (n_fft as f32);
103                cos_data[k * n_fft + n] = angle.cos();
104                sin_data[k * n_fft + n] = angle.sin();
105            }
106        }
107        let dft_cos = Tensor::new(cos_data.as_slice(), device)?.reshape((n_freq, n_fft))?;
108        let dft_sin = Tensor::new(sin_data.as_slice(), device)?.reshape((n_freq, n_fft))?;
109
110        // ── Hann window (centre-padded to n_fft) ─────────────────────
111        let mut window_data = vec![0f32; n_fft];
112        let pad_left = (n_fft - config.win_length) / 2;
113        for i in 0..config.win_length {
114            let w = 0.5
115                * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / config.win_length as f32).cos());
116            window_data[pad_left + i] = w;
117        }
118        let window = Tensor::new(window_data.as_slice(), device)?;
119
120        // ── Mel filterbank ────────────────────────────────────────────
121        let mel_basis =
122            Self::build_mel_filterbank(config.n_mels, n_freq, config.sample_rate, device)?;
123
124        Ok(Self {
125            config,
126            dft_cos,
127            dft_sin,
128            window,
129            mel_basis,
130        })
131    }
132
133    /// Compute a log-mel spectrogram from raw PCM audio.
134    ///
135    /// * **Input:** 1-D `[num_samples]` f32 tensor at [`MelConfig::sample_rate`].
136    /// * **Output:** `[1, n_mels, num_frames]` normalised log-mel spectrogram.
137    pub fn compute(&self, audio: &Tensor) -> TtsResult<Tensor> {
138        let audio = audio.to_dtype(DType::F32)?;
139        let n_samples = audio.dim(0)?;
140        let n_fft = self.config.n_fft;
141        let hop = self.config.hop_length;
142
143        // ── Reflect-pad the signal ────────────────────────────────────
144        let pad_len = n_fft / 2;
145        let zeros_l = Tensor::zeros(pad_len, DType::F32, audio.device())?;
146        let zeros_r_len = (n_samples + 2 * pad_len).saturating_sub(n_samples + pad_len);
147        let zeros_r = Tensor::zeros(pad_len.max(zeros_r_len), DType::F32, audio.device())?;
148        let padded = Tensor::cat(&[&zeros_l, &audio, &zeros_r], 0)?;
149        let padded_len = padded.dim(0)?;
150
151        // ── Frame the signal ──────────────────────────────────────────
152        let num_frames = padded_len.saturating_sub(n_fft) / hop + 1;
153        if num_frames == 0 {
154            return Err(TtsError::ModelError(
155                "Audio too short for mel spectrogram extraction".into(),
156            ));
157        }
158
159        let mut frames = Vec::with_capacity(num_frames);
160        for i in 0..num_frames {
161            let start = i * hop;
162            let frame = padded.narrow(0, start, n_fft)?;
163            let windowed = (&frame * &self.window)?;
164            frames.push(windowed);
165        }
166        let frames = Tensor::stack(&frames, 0)?; // [num_frames, n_fft]
167
168        // ── DFT via matrix multiply ───────────────────────────────────
169        let x_real = frames.matmul(&self.dft_cos.t()?)?; // [num_frames, n_freq]
170        let x_imag = frames.matmul(&self.dft_sin.t()?)?;
171
172        // Power spectrum
173        let power = (x_real.sqr()? + x_imag.sqr()?)?; // [num_frames, n_freq]
174
175        // ── Mel filterbank projection ─────────────────────────────────
176        // mel_basis [n_mels, n_freq] × power^T [n_freq, num_frames] → [n_mels, frames]
177        let mel = self.mel_basis.matmul(&power.t()?)?;
178
179        // ── Log compression + normalisation ───────────────────────────
180        let log_mel = (mel + 1e-5)?.log()?;
181        let normalised = log_mel.affine(
182            1.0 / self.config.log_std,
183            -self.config.log_mean / self.config.log_std,
184        )?;
185
186        // [1, n_mels, num_frames]
187        normalised.unsqueeze(0).map_err(TtsError::from)
188    }
189
190    /// Get the config used by this extractor.
191    pub fn config(&self) -> &MelConfig {
192        &self.config
193    }
194
195    // ── Private helpers ───────────────────────────────────────────────
196
197    /// Build a triangular mel filterbank `[n_mels, n_freq]`.
198    fn build_mel_filterbank(
199        n_mels: usize,
200        n_freq: usize,
201        sample_rate: u32,
202        device: &Device,
203    ) -> TtsResult<Tensor> {
204        let sr = sample_rate as f32;
205        let fmax = sr / 2.0;
206
207        let hz_to_mel = |hz: f32| -> f32 { 2595.0 * (1.0 + hz / 700.0).log10() };
208        let mel_to_hz = |m: f32| -> f32 { 700.0 * (10.0f32.powf(m / 2595.0) - 1.0) };
209
210        let mel_min = hz_to_mel(0.0);
211        let mel_max = hz_to_mel(fmax);
212
213        // n_mels + 2 equally-spaced points in mel space
214        let n_points = n_mels + 2;
215        let mel_points: Vec<f32> = (0..n_points)
216            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_points - 1) as f32)
217            .collect();
218        let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
219
220        // Convert Hz → FFT bin (fractional)
221        let bin_points: Vec<f32> = hz_points
222            .iter()
223            .map(|&hz| hz * (n_freq as f32 - 1.0) * 2.0 / sr)
224            .collect();
225
226        // Triangular filters
227        let mut filters = vec![0f32; n_mels * n_freq];
228        for m in 0..n_mels {
229            let f_left = bin_points[m];
230            let f_center = bin_points[m + 1];
231            let f_right = bin_points[m + 2];
232
233            for k in 0..n_freq {
234                let kf = k as f32;
235                if kf >= f_left && kf <= f_center && f_center > f_left {
236                    filters[m * n_freq + k] = (kf - f_left) / (f_center - f_left);
237                } else if kf > f_center && kf <= f_right && f_right > f_center {
238                    filters[m * n_freq + k] = (f_right - kf) / (f_right - f_center);
239                }
240            }
241        }
242
243        Tensor::new(filters.as_slice(), device)?
244            .reshape((n_mels, n_freq))
245            .map_err(TtsError::from)
246    }
247}
248
/// Resample audio from `src_rate` to `dst_rate` using linear interpolation.
///
/// For voice cloning, the reference audio must match the model's expected
/// sample rate. This function handles the conversion.
///
/// The output holds exactly `ceil(samples.len() * dst_rate / src_rate)`
/// samples. Output sample `i` is read at source position
/// `i * src_rate / dst_rate`, interpolating between the two neighbouring
/// input samples; positions past the last input sample repeat that sample.
/// Lengths and positions are tracked with integer arithmetic, so rate pairs
/// such as 44.1 kHz → 48 kHz do not gain a stray sample from floating-point
/// rounding.
///
/// Returns a copy of `samples` when the rates are equal or the input is
/// empty, and an empty vector when exactly one of the rates is zero.
///
/// No low-pass filtering is applied, so downsampling can alias.
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero rate describes no signal; return nothing rather than dividing
    // by zero or attempting an unbounded allocation.
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }

    let src = u64::from(src_rate);
    let dst = u64::from(dst_rate);

    // Exact ceil(len * dst / src); u128 cannot overflow for any slice length.
    let scaled_len = samples.len() as u128 * u128::from(dst_rate);
    let out_len =
        usize::try_from(scaled_len.div_ceil(u128::from(src_rate))).unwrap_or(usize::MAX);

    // The source position advances by src/dst per output sample. Track it as
    // an integer index plus a remainder in units of 1/dst, so it never drifts.
    let whole_step = (src / dst) as usize;
    let frac_step = src % dst;
    let last = samples.len() - 1;

    let mut output = Vec::with_capacity(out_len);
    let mut idx = 0usize;
    let mut rem = 0u64;
    for _ in 0..out_len {
        let s0 = samples[idx.min(last)];
        let s1 = samples[idx.saturating_add(1).min(last)];
        let frac = (rem as f64 / dst as f64) as f32;
        output.push(s0 + frac * (s1 - s0));

        idx = idx.saturating_add(whole_step);
        rem += frac_step;
        if rem >= dst {
            rem -= dst;
            idx = idx.saturating_add(1);
        }
    }

    output
}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// The Kokoro preset exposes the documented STFT/mel parameters.
    #[test]
    fn test_mel_config_kokoro() {
        let cfg = MelConfig::kokoro();
        assert_eq!(cfg.n_fft, 2048);
        assert_eq!(cfg.n_freq(), 1025);
        assert_eq!(cfg.n_mels, 80);
        assert_eq!(cfg.sample_rate, 24000);
    }

    /// One second of audio yields a `[1, n_mels, num_frames]` spectrogram.
    #[test]
    fn test_mel_spectrogram_shape() {
        let device = Device::Cpu;
        let cfg = MelConfig::kokoro();
        let mel = MelSpectrogram::new(cfg, &device).unwrap();

        // 1 second of audio at 24kHz
        let audio = Tensor::zeros(24000, DType::F32, &device).unwrap();
        let spec = mel.compute(&audio).unwrap();

        // The signal is padded by n_fft / 2 on each side, so
        // num_frames = (24000 + 2048 - 2048) / 300 + 1 = 81.
        assert_eq!(spec.dims(), &[1, 80, 81]);
    }

    /// The filterbank has one row per mel band and one column per FFT bin.
    #[test]
    fn test_mel_filterbank_shape() {
        let device = Device::Cpu;
        let fb = MelSpectrogram::build_mel_filterbank(80, 1025, 24000, &device).unwrap();
        assert_eq!(fb.dims(), &[80, 1025]);
    }

    /// Every mel band covers at least one FFT bin.
    #[test]
    fn test_mel_filterbank_values() {
        let device = Device::Cpu;
        let fb = MelSpectrogram::build_mel_filterbank(80, 1025, 24000, &device).unwrap();
        let data: Vec<Vec<f32>> = fb.to_vec2().unwrap();

        // Each row should have at least some non-zero values (triangular filter)
        for row in &data {
            let sum: f32 = row.iter().sum();
            assert!(sum > 0.0, "Mel filter band has zero energy");
        }
    }

    /// Equal source and destination rates return the input unchanged.
    #[test]
    fn test_resample_identity() {
        let samples = vec![1.0, 2.0, 3.0, 4.0];
        let out = resample_linear(&samples, 16000, 16000);
        assert_eq!(out, samples);
    }

    /// Upsampling by 4 interpolates linearly and then holds the last sample.
    #[test]
    fn test_resample_upsample() {
        let samples = vec![0.0, 1.0];
        let out = resample_linear(&samples, 1, 4);
        assert_eq!(out.len(), 8);

        let expected = [0.0f32, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0];
        for (i, (got, want)) in out.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-6,
                "sample {i}: got {got}, expected {want}"
            );
        }
    }

    /// Empty input stays empty regardless of the rates.
    #[test]
    fn test_resample_empty() {
        let out = resample_linear(&[], 16000, 24000);
        assert!(out.is_empty());
    }
}
345}