1use candle_core::{DType, Device, Tensor};
20
21use crate::error::{TtsError, TtsResult};
22
/// Parameters controlling mel-spectrogram extraction.
///
/// All fields are public so callers can build custom configurations with a
/// struct literal; [`MelConfig::kokoro`] provides a ready-made preset.
#[derive(Clone, Debug)]
pub struct MelConfig {
    /// Size of the DFT applied to every frame, in samples.
    pub n_fft: usize,
    /// Number of samples between the starts of consecutive frames.
    pub hop_length: usize,
    /// Length of the analysis window, in samples; the window is centred
    /// inside an `n_fft`-sample frame.
    pub win_length: usize,
    /// Number of mel bands produced per frame.
    pub n_mels: usize,
    /// Sample rate of the audio being analysed, in Hz.
    pub sample_rate: u32,
    /// Mean subtracted from the log-mel values during normalisation.
    pub log_mean: f64,
    /// Divisor applied to the mean-shifted log-mel values during normalisation.
    pub log_std: f64,
}
45
46impl MelConfig {
47 pub fn kokoro() -> Self {
54 Self {
55 n_fft: 2048,
56 hop_length: 300,
57 win_length: 1200,
58 n_mels: 80,
59 sample_rate: 24000,
60 log_mean: -4.0,
61 log_std: 4.0,
62 }
63 }
64
65 pub fn n_freq(&self) -> usize {
67 self.n_fft / 2 + 1
68 }
69}
70
71pub struct MelSpectrogram {
80 config: MelConfig,
81 dft_cos: Tensor,
83 dft_sin: Tensor,
85 window: Tensor,
87 mel_basis: Tensor,
89}
90
91impl MelSpectrogram {
92 pub fn new(config: MelConfig, device: &Device) -> TtsResult<Self> {
94 let n_fft = config.n_fft;
95 let n_freq = config.n_freq();
96
97 let mut cos_data = vec![0f32; n_freq * n_fft];
99 let mut sin_data = vec![0f32; n_freq * n_fft];
100 for k in 0..n_freq {
101 for n in 0..n_fft {
102 let angle = 2.0 * std::f32::consts::PI * (k as f32) * (n as f32) / (n_fft as f32);
103 cos_data[k * n_fft + n] = angle.cos();
104 sin_data[k * n_fft + n] = angle.sin();
105 }
106 }
107 let dft_cos = Tensor::new(cos_data.as_slice(), device)?.reshape((n_freq, n_fft))?;
108 let dft_sin = Tensor::new(sin_data.as_slice(), device)?.reshape((n_freq, n_fft))?;
109
110 let mut window_data = vec![0f32; n_fft];
112 let pad_left = (n_fft - config.win_length) / 2;
113 for i in 0..config.win_length {
114 let w = 0.5
115 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / config.win_length as f32).cos());
116 window_data[pad_left + i] = w;
117 }
118 let window = Tensor::new(window_data.as_slice(), device)?;
119
120 let mel_basis =
122 Self::build_mel_filterbank(config.n_mels, n_freq, config.sample_rate, device)?;
123
124 Ok(Self {
125 config,
126 dft_cos,
127 dft_sin,
128 window,
129 mel_basis,
130 })
131 }
132
133 pub fn compute(&self, audio: &Tensor) -> TtsResult<Tensor> {
138 let audio = audio.to_dtype(DType::F32)?;
139 let n_samples = audio.dim(0)?;
140 let n_fft = self.config.n_fft;
141 let hop = self.config.hop_length;
142
143 let pad_len = n_fft / 2;
145 let zeros_l = Tensor::zeros(pad_len, DType::F32, audio.device())?;
146 let zeros_r_len = (n_samples + 2 * pad_len).saturating_sub(n_samples + pad_len);
147 let zeros_r = Tensor::zeros(pad_len.max(zeros_r_len), DType::F32, audio.device())?;
148 let padded = Tensor::cat(&[&zeros_l, &audio, &zeros_r], 0)?;
149 let padded_len = padded.dim(0)?;
150
151 let num_frames = padded_len.saturating_sub(n_fft) / hop + 1;
153 if num_frames == 0 {
154 return Err(TtsError::ModelError(
155 "Audio too short for mel spectrogram extraction".into(),
156 ));
157 }
158
159 let mut frames = Vec::with_capacity(num_frames);
160 for i in 0..num_frames {
161 let start = i * hop;
162 let frame = padded.narrow(0, start, n_fft)?;
163 let windowed = (&frame * &self.window)?;
164 frames.push(windowed);
165 }
166 let frames = Tensor::stack(&frames, 0)?; let x_real = frames.matmul(&self.dft_cos.t()?)?; let x_imag = frames.matmul(&self.dft_sin.t()?)?;
171
172 let power = (x_real.sqr()? + x_imag.sqr()?)?; let mel = self.mel_basis.matmul(&power.t()?)?;
178
179 let log_mel = (mel + 1e-5)?.log()?;
181 let normalised = log_mel.affine(
182 1.0 / self.config.log_std,
183 -self.config.log_mean / self.config.log_std,
184 )?;
185
186 normalised.unsqueeze(0).map_err(TtsError::from)
188 }
189
190 pub fn config(&self) -> &MelConfig {
192 &self.config
193 }
194
195 fn build_mel_filterbank(
199 n_mels: usize,
200 n_freq: usize,
201 sample_rate: u32,
202 device: &Device,
203 ) -> TtsResult<Tensor> {
204 let sr = sample_rate as f32;
205 let fmax = sr / 2.0;
206
207 let hz_to_mel = |hz: f32| -> f32 { 2595.0 * (1.0 + hz / 700.0).log10() };
208 let mel_to_hz = |m: f32| -> f32 { 700.0 * (10.0f32.powf(m / 2595.0) - 1.0) };
209
210 let mel_min = hz_to_mel(0.0);
211 let mel_max = hz_to_mel(fmax);
212
213 let n_points = n_mels + 2;
215 let mel_points: Vec<f32> = (0..n_points)
216 .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_points - 1) as f32)
217 .collect();
218 let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
219
220 let bin_points: Vec<f32> = hz_points
222 .iter()
223 .map(|&hz| hz * (n_freq as f32 - 1.0) * 2.0 / sr)
224 .collect();
225
226 let mut filters = vec![0f32; n_mels * n_freq];
228 for m in 0..n_mels {
229 let f_left = bin_points[m];
230 let f_center = bin_points[m + 1];
231 let f_right = bin_points[m + 2];
232
233 for k in 0..n_freq {
234 let kf = k as f32;
235 if kf >= f_left && kf <= f_center && f_center > f_left {
236 filters[m * n_freq + k] = (kf - f_left) / (f_center - f_left);
237 } else if kf > f_center && kf <= f_right && f_right > f_center {
238 filters[m * n_freq + k] = (f_right - kf) / (f_right - f_center);
239 }
240 }
241 }
242
243 Tensor::new(filters.as_slice(), device)?
244 .reshape((n_mels, n_freq))
245 .map_err(TtsError::from)
246 }
247}
248
/// Resamples `samples` from `src_rate` to `dst_rate` by linear interpolation.
///
/// The output holds `ceil(len * dst_rate / src_rate)` samples. Output sample
/// `i` is read at source position `i * src_rate / dst_rate`, interpolating
/// between the two neighbouring input samples; positions past the final input
/// sample repeat that final sample.
///
/// Equal rates or empty input return a copy of `samples`. If exactly one of
/// the rates is zero the conversion is undefined and an empty vector is
/// returned (a zero `src_rate` used to panic with a capacity overflow).
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero rate would make the ratio 0 or infinite; bail out before the
    // output length is derived from it.
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(dst_rate) / f64::from(src_rate);
    let out_len = (samples.len() as f64 * ratio).ceil() as usize;
    let last = samples.len() - 1;

    (0..out_len)
        .map(|i| {
            let src_pos = i as f64 / ratio;
            let lower = src_pos.floor() as usize;
            let frac = (src_pos - lower as f64) as f32;

            // Clamp both taps so the tail of the output holds the last sample.
            let s0 = samples[lower.min(last)];
            let s1 = samples[(lower + 1).min(last)];
            s0 + frac * (s1 - s0)
        })
        .collect()
}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Filterbank for the 80-band, 1025-bin, 24 kHz setup used by the
    /// filterbank tests below.
    fn reference_filterbank() -> Tensor {
        let cpu = Device::Cpu;
        MelSpectrogram::build_mel_filterbank(80, 1025, 24000, &cpu).unwrap()
    }

    #[test]
    fn test_mel_config_kokoro() {
        let preset = MelConfig::kokoro();
        assert_eq!(preset.n_fft, 2048);
        assert_eq!(preset.n_freq(), 1025);
        assert_eq!(preset.n_mels, 80);
        assert_eq!(preset.sample_rate, 24000);
    }

    #[test]
    fn test_mel_spectrogram_shape() {
        let cpu = Device::Cpu;
        let extractor = MelSpectrogram::new(MelConfig::kokoro(), &cpu).unwrap();

        // 24000 samples of silence.
        let silence = Tensor::zeros(24000, DType::F32, &cpu).unwrap();
        let spec = extractor.compute(&silence).unwrap();

        // Expected layout: (batch, mel bands, frames).
        let dims = spec.dims();
        assert_eq!(dims[0], 1);
        assert_eq!(dims[1], 80);
        assert!(dims[2] > 50);
    }

    #[test]
    fn test_mel_filterbank_shape() {
        let bank = reference_filterbank();
        assert_eq!(bank.dims(), &[80, 1025]);
    }

    #[test]
    fn test_mel_filterbank_values() {
        let rows = reference_filterbank().to_vec2::<f32>().unwrap();

        // Every band must put weight on at least one frequency bin.
        for band in &rows {
            let total: f32 = band.iter().sum();
            assert!(total > 0.0, "Mel filter band has zero energy");
        }
    }

    #[test]
    fn test_resample_identity() {
        let input: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let output = resample_linear(&input, 16000, 16000);
        assert_eq!(output, input);
    }

    #[test]
    fn test_resample_upsample() {
        let input: Vec<f32> = vec![0.0, 1.0];
        let output = resample_linear(&input, 1, 4);
        assert_eq!(output.len(), 8);
        assert!(output[0].abs() < 0.01);
    }

    #[test]
    fn test_resample_empty() {
        let output = resample_linear(&[], 16000, 24000);
        assert!(output.is_empty());
    }
}