use candle_core::{DType, Device, Tensor};
use crate::error::{TtsError, TtsResult};
/// Parameters controlling mel-spectrogram extraction.
///
/// All lengths are expressed in samples.
#[derive(Debug, Clone)]
pub struct MelConfig {
    /// Size of each analysis frame (DFT length); `n_fft / 2 + 1` frequency bins are produced.
    pub n_fft: usize,
    /// Number of samples the analysis frame advances between consecutive frames.
    pub hop_length: usize,
    /// Length of the Hann window, which is centred inside the `n_fft`-sample frame.
    pub win_length: usize,
    /// Number of mel bands in the output.
    pub n_mels: usize,
    /// Sample rate of the input audio in Hz; the mel filterbank spans 0 Hz to half this rate.
    pub sample_rate: u32,
    /// Value subtracted from the log-mel energies during normalisation.
    pub log_mean: f64,
    /// Value the mean-shifted log-mel energies are divided by during normalisation.
    pub log_std: f64,
}
impl MelConfig {
pub fn kokoro() -> Self {
Self {
n_fft: 2048,
hop_length: 300,
win_length: 1200,
n_mels: 80,
sample_rate: 24000,
log_mean: -4.0,
log_std: 4.0,
}
}
pub fn n_freq(&self) -> usize {
self.n_fft / 2 + 1
}
}
/// Mel-spectrogram extractor holding its configuration and the constant
/// tensors that are built once at construction time.
pub struct MelSpectrogram {
    /// Extraction parameters this instance was built with.
    config: MelConfig,
    /// Cosine half of the DFT basis, `cos(2πkn / n_fft)`, shaped `(n_freq, n_fft)`.
    dft_cos: Tensor,
    /// Sine half of the DFT basis, `sin(2πkn / n_fft)`, shaped `(n_freq, n_fft)`.
    dft_sin: Tensor,
    /// Hann window of `win_length` samples, zero-padded to `n_fft` samples.
    window: Tensor,
    /// Triangular mel filterbank shaped `(n_mels, n_freq)`.
    mel_basis: Tensor,
}
impl MelSpectrogram {
    /// Builds an extractor for `config`, precomputing all constant tensors on `device`.
    ///
    /// Precomputed tensors:
    /// * the cosine and sine DFT bases, each shaped `(n_freq, n_fft)`;
    /// * a Hann window of `win_length` samples centred in an `n_fft`-sample frame;
    /// * the mel filterbank shaped `(n_mels, n_freq)`.
    ///
    /// # Errors
    ///
    /// Returns [`TtsError::ModelError`] when the configuration cannot be used
    /// (`n_fft`, `hop_length` or `sample_rate` is zero, or `win_length > n_fft`),
    /// and propagates any tensor-creation error.
    pub fn new(config: MelConfig, device: &Device) -> TtsResult<Self> {
        Self::validate_config(&config)?;

        let n_fft = config.n_fft;
        let n_freq = config.n_freq();

        // The basis is periodic in k*n with period n_fft, so the phase index is
        // reduced with exact integer arithmetic and the angle evaluated in f64.
        // Evaluating 2πkn/n_fft directly in f32 loses precision once the angle
        // reaches thousands of radians.
        let mut cos_data = Vec::with_capacity(n_freq * n_fft);
        let mut sin_data = Vec::with_capacity(n_freq * n_fft);
        for k in 0..n_freq {
            for n in 0..n_fft {
                let phase_idx = (k * n) % n_fft;
                let angle = 2.0 * std::f64::consts::PI * phase_idx as f64 / n_fft as f64;
                cos_data.push(angle.cos() as f32);
                sin_data.push(angle.sin() as f32);
            }
        }
        let dft_cos = Tensor::new(cos_data.as_slice(), device)?.reshape((n_freq, n_fft))?;
        let dft_sin = Tensor::new(sin_data.as_slice(), device)?.reshape((n_freq, n_fft))?;

        // Hann window centred in the frame; samples outside it stay zero.
        // `win_length <= n_fft` was checked above, so the subtraction cannot underflow.
        let win_length = config.win_length;
        let pad_left = (n_fft - win_length) / 2;
        let mut window_data = vec![0f32; n_fft];
        for (i, slot) in window_data[pad_left..pad_left + win_length]
            .iter_mut()
            .enumerate()
        {
            *slot = 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / win_length as f32).cos());
        }
        let window = Tensor::new(window_data.as_slice(), device)?;

        let mel_basis =
            Self::build_mel_filterbank(config.n_mels, n_freq, config.sample_rate, device)?;

        Ok(Self {
            config,
            dft_cos,
            dft_sin,
            window,
            mel_basis,
        })
    }

    /// Computes the normalised log-mel spectrogram of `audio`.
    ///
    /// `audio` is converted to `f32` and is expected to be 1-D: it is
    /// concatenated with 1-D zero padding along dimension 0. The signal is
    /// zero-padded by `n_fft / 2` samples on each side, cut into frames of
    /// `n_fft` samples every `hop_length` samples, windowed, transformed to a
    /// power spectrum, projected onto the mel filterbank, passed through
    /// `log(x + 1e-5)` and finally normalised as `(x - log_mean) / log_std`.
    ///
    /// Returns a tensor shaped `(1, n_mels, num_frames)`.
    ///
    /// # Errors
    ///
    /// Returns [`TtsError::ModelError`] when the padded signal is shorter than
    /// one frame, and propagates any tensor-operation error.
    pub fn compute(&self, audio: &Tensor) -> TtsResult<Tensor> {
        let audio = audio.to_dtype(DType::F32)?;
        let n_fft = self.config.n_fft;
        let hop = self.config.hop_length;

        // Symmetric zero padding of half a frame on each side.
        let padding = Tensor::zeros(n_fft / 2, DType::F32, audio.device())?;
        let padded = Tensor::cat(&[&padding, &audio, &padding], 0)?;
        let padded_len = padded.dim(0)?;

        // At least one full frame must fit, otherwise `narrow` below would fail
        // with an opaque shape error.
        if padded_len < n_fft {
            return Err(TtsError::ModelError(
                "Audio too short for mel spectrogram extraction".into(),
            ));
        }
        // `hop` is non-zero: `new` rejects such configurations.
        let num_frames = (padded_len - n_fft) / hop + 1;

        let frames = (0..num_frames)
            .map(|i| -> candle_core::Result<Tensor> {
                let frame = padded.narrow(0, i * hop, n_fft)?;
                &frame * &self.window
            })
            .collect::<candle_core::Result<Vec<Tensor>>>()?;
        // (num_frames, n_fft)
        let frames = Tensor::stack(&frames, 0)?;

        // Real DFT as two matrix products: (num_frames, n_fft) x (n_fft, n_freq).
        let x_real = frames.matmul(&self.dft_cos.t()?)?;
        let x_imag = frames.matmul(&self.dft_sin.t()?)?;
        let power = (x_real.sqr()? + x_imag.sqr()?)?;

        // (n_mels, n_freq) x (n_freq, num_frames) -> (n_mels, num_frames)
        let mel = self.mel_basis.matmul(&power.t()?)?;
        // The small offset keeps the logarithm finite for silent frames.
        let log_mel = (mel + 1e-5)?.log()?;

        // (x - mean) / std written as the affine map x * (1/std) + (-mean/std).
        let scale = 1.0 / self.config.log_std;
        let shift = -self.config.log_mean / self.config.log_std;
        let normalised = log_mel.affine(scale, shift)?;

        normalised.unsqueeze(0).map_err(TtsError::from)
    }

    /// Returns the configuration this extractor was built with.
    pub fn config(&self) -> &MelConfig {
        &self.config
    }

    /// Rejects configurations that would panic or silently yield unusable output.
    fn validate_config(config: &MelConfig) -> TtsResult<()> {
        if config.n_fft == 0 {
            return Err(TtsError::ModelError(
                "Invalid mel configuration: n_fft must be non-zero".into(),
            ));
        }
        if config.hop_length == 0 {
            return Err(TtsError::ModelError(
                "Invalid mel configuration: hop_length must be non-zero".into(),
            ));
        }
        if config.win_length > config.n_fft {
            return Err(TtsError::ModelError(
                "Invalid mel configuration: win_length must not exceed n_fft".into(),
            ));
        }
        if config.sample_rate == 0 {
            return Err(TtsError::ModelError(
                "Invalid mel configuration: sample_rate must be non-zero".into(),
            ));
        }
        Ok(())
    }

    /// Builds a triangular mel filterbank shaped `(n_mels, n_freq)`.
    ///
    /// `n_mels + 2` points are spaced evenly on the mel scale
    /// (`2595 * log10(1 + hz / 700)`) between 0 Hz and `sample_rate / 2`, then
    /// mapped to fractional frequency-bin positions. Filter `m` rises linearly
    /// from point `m` to point `m + 1` and falls linearly to point `m + 2`;
    /// the filters are not area-normalised.
    fn build_mel_filterbank(
        n_mels: usize,
        n_freq: usize,
        sample_rate: u32,
        device: &Device,
    ) -> TtsResult<Tensor> {
        let sr = sample_rate as f32;
        let hz_to_mel = |hz: f32| -> f32 { 2595.0 * (1.0 + hz / 700.0).log10() };
        let mel_to_hz = |m: f32| -> f32 { 700.0 * (10.0f32.powf(m / 2595.0) - 1.0) };

        let mel_min = hz_to_mel(0.0);
        let mel_max = hz_to_mel(sr / 2.0);
        let n_points = n_mels + 2;

        // Band edges as fractional bin indices: bin k covers k * sr / (2 * (n_freq - 1)) Hz.
        let bin_points: Vec<f32> = (0..n_points)
            .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_points - 1) as f32)
            .map(mel_to_hz)
            .map(|hz| hz * (n_freq as f32 - 1.0) * 2.0 / sr)
            .collect();

        let mut filters = vec![0f32; n_mels * n_freq];
        for m in 0..n_mels {
            let (f_left, f_center, f_right) = (bin_points[m], bin_points[m + 1], bin_points[m + 2]);
            let row = &mut filters[m * n_freq..(m + 1) * n_freq];
            for (k, weight) in row.iter_mut().enumerate() {
                let kf = k as f32;
                if f_center > f_left && kf >= f_left && kf <= f_center {
                    // Rising slope.
                    *weight = (kf - f_left) / (f_center - f_left);
                } else if f_right > f_center && kf > f_center && kf <= f_right {
                    // Falling slope.
                    *weight = (f_right - kf) / (f_right - f_center);
                }
            }
        }

        Tensor::new(filters.as_slice(), device)?
            .reshape((n_mels, n_freq))
            .map_err(TtsError::from)
    }
}
/// Resamples `samples` from `src_rate` to `dst_rate` using linear interpolation.
///
/// The output holds `ceil(samples.len() * dst_rate / src_rate)` samples.
/// Output sample `i` is taken at source position `i * src_rate / dst_rate`,
/// interpolating between the two neighbouring input samples; positions past
/// the final input sample repeat that sample.
///
/// Returns a copy of the input when the rates are equal or the input is
/// empty, and an empty vector when exactly one of the rates is zero (the
/// conversion is undefined in that case).
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero source rate used to produce an infinite ratio and a
    // capacity-overflow panic; a zero target rate already yielded no samples.
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }

    // Exact integer arithmetic: a floating-point ratio can round a
    // whole-number length upwards and emit one sample too many.
    let src = u128::from(src_rate);
    let dst = u128::from(dst_rate);
    let last = samples.len() - 1;
    let exact_len = (samples.len() as u128 * dst + src - 1) / src;
    let out_len = usize::try_from(exact_len).unwrap_or(usize::MAX);

    (0..out_len)
        .map(|i| {
            // Source position of output sample `i` is (i * src) / dst,
            // split into its integer part and fractional remainder.
            let pos = i as u128 * src;
            let idx = usize::try_from(pos / dst).unwrap_or(usize::MAX);
            let frac = ((pos % dst) as f64 / dst as f64) as f32;
            let s0 = samples[idx.min(last)];
            let s1 = samples[idx.saturating_add(1).min(last)];
            s0 + frac * (s1 - s0)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Filterbank for the dimensions used throughout these tests
    /// (80 mel bands, 1025 frequency bins, 24 kHz), built on the CPU.
    fn reference_filterbank() -> Tensor {
        MelSpectrogram::build_mel_filterbank(80, 1025, 24000, &Device::Cpu).unwrap()
    }

    /// The `kokoro` preset exposes the expected core parameters.
    #[test]
    fn test_mel_config_kokoro() {
        let preset = MelConfig::kokoro();
        assert_eq!(preset.n_fft, 2048);
        assert_eq!(preset.n_freq(), 1025);
        assert_eq!(preset.n_mels, 80);
        assert_eq!(preset.sample_rate, 24000);
    }

    /// One second of silence yields a `(1, 80, frames)` tensor with more than 50 frames.
    #[test]
    fn test_mel_spectrogram_shape() {
        let cpu = Device::Cpu;
        let extractor = MelSpectrogram::new(MelConfig::kokoro(), &cpu).unwrap();
        let silence = Tensor::zeros(24000, DType::F32, &cpu).unwrap();

        let spec = extractor.compute(&silence).unwrap();
        let shape = spec.dims();
        assert_eq!(shape[0], 1);
        assert_eq!(shape[1], 80);
        assert!(shape[2] > 50);
    }

    /// The filterbank has one row per mel band and one column per frequency bin.
    #[test]
    fn test_mel_filterbank_shape() {
        let bank = reference_filterbank();
        assert_eq!(bank.dims(), &[80, 1025]);
    }

    /// Every mel band must cover at least one frequency bin with positive weight.
    #[test]
    fn test_mel_filterbank_values() {
        let rows: Vec<Vec<f32>> = reference_filterbank().to_vec2().unwrap();
        for band in rows.iter() {
            let energy = band.iter().sum::<f32>();
            assert!(energy > 0.0, "Mel filter band has zero energy");
        }
    }

    /// Equal source and target rates return the input unchanged.
    #[test]
    fn test_resample_identity() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let resampled = resample_linear(&input, 16000, 16000);
        assert_eq!(resampled, input);
    }

    /// Upsampling by a factor of four quadruples the length and starts on the first sample.
    #[test]
    fn test_resample_upsample() {
        let input = vec![0.0, 1.0];
        let resampled = resample_linear(&input, 1, 4);
        assert_eq!(resampled.len(), 8);
        assert!((resampled[0] - 0.0).abs() < 0.01);
    }

    /// Empty input stays empty regardless of the rates.
    #[test]
    fn test_resample_empty() {
        let resampled = resample_linear(&[], 16000, 24000);
        assert!(resampled.is_empty());
    }
}