use serde::{Deserialize, Serialize};
/// Mel-scale variant stored in a [`MelConfig`].
///
/// With `rename_all = "lowercase"`, serde reads and writes the variants as
/// `"slaney"` and `"htk"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MelScale {
    /// The variant produced by `MelScale::default()`.
    #[default]
    Slaney,
    /// Alternative variant; used by `MelConfig::htk_default()`.
    Htk,
}
/// Padding strategy stored in a [`MelConfig`].
///
/// With `rename_all = "lowercase"`, serde reads and writes the variants as
/// `"reflect"`, `"zero"` and `"none"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PaddingMode {
    /// The variant produced by `PaddingMode::default()`.
    #[default]
    Reflect,
    /// Padding variant for which `MelConfig::pad_size` reports the same
    /// length as [`PaddingMode::Reflect`].
    Zero,
    /// No padding: `MelConfig::pad_size` reports `0`.
    None,
}
/// Parameters of a mel-spectrogram front end.
///
/// When deserializing, `n_mels`, `n_fft`, `hop_length` and `sample_rate`
/// must be present; every other field has a serde fallback.
///
/// NOTE: the serde fallbacks are not the same as `MelConfig::default()`.
/// An omitted `f_max` deserializes to `0.0` and an omitted `max_frames` to
/// `None`, whereas `MelConfig::default()` uses `8000.0` and `Some(3000)`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct MelConfig {
    /// Mel bin count (80 for the base Whisper preset, 128 for the large one).
    pub n_mels: usize,
    /// FFT window size; `pad_size` derives the pad length from it.
    pub n_fft: usize,
    /// Hop between frames (presumably in samples; confirm against the caller).
    pub hop_length: usize,
    /// Sample rate (presumably in Hz); `effective_f_max` halves it to get the
    /// fallback upper frequency.
    pub sample_rate: u32,
    /// Mel-scale variant; falls back to `MelScale::default()` when omitted.
    #[serde(default)]
    pub mel_scale: MelScale,
    /// Lower frequency bound; falls back to `0.0` when omitted.
    #[serde(default)]
    pub f_min: f64,
    /// Upper frequency bound. Any value `<= 0.0`, including the serde
    /// fallback of `0.0`, makes `effective_f_max` return `sample_rate / 2`.
    #[serde(default)]
    pub f_max: f64,
    /// Padding strategy; falls back to `PaddingMode::default()` when omitted.
    #[serde(default)]
    pub padding: PaddingMode,
    /// Optional frame cap; `None` when omitted.
    #[serde(default)]
    pub max_frames: Option<usize>,
    /// Normalization switch; `true` when omitted (see `default_normalize`).
    #[serde(default = "default_normalize")]
    pub normalize: bool,
}
/// Serde fallback for `MelConfig::normalize`: a config that omits the field
/// deserializes with normalization switched on.
fn default_normalize() -> bool {
    const NORMALIZE_WHEN_OMITTED: bool = true;
    NORMALIZE_WHEN_OMITTED
}
impl Default for MelConfig {
fn default() -> Self {
Self {
n_mels: 80,
n_fft: 400,
hop_length: 160,
sample_rate: 16000,
mel_scale: MelScale::Slaney,
f_min: 0.0,
f_max: 8000.0,
padding: PaddingMode::Reflect,
max_frames: Some(3000),
normalize: true,
}
}
}
impl MelConfig {
    /// Looks up a built-in configuration by preset name.
    ///
    /// The name is lowercased before matching, so lookup is case-insensitive.
    /// `"whisper"`, `"whisper-tiny"`, `"whisper-base"`, `"whisper-small"` and
    /// `"whisper-medium"` map to [`MelConfig::whisper`]; `"whisper-large"`,
    /// `"whisper-large-v2"` and `"whisper-large-v3"` map to
    /// [`MelConfig::whisper_large`]. Any other name returns `None`.
    pub fn from_preset(preset: &str) -> Option<Self> {
        match preset.to_lowercase().as_str() {
            "whisper" | "whisper-tiny" | "whisper-base" | "whisper-small" | "whisper-medium" => {
                Some(Self::whisper())
            }
            "whisper-large" | "whisper-large-v2" | "whisper-large-v3" => {
                Some(Self::whisper_large())
            }
            _ => None,
        }
    }

    /// Base Whisper preset: 80 mel bins, 400-point FFT, hop of 160 at a
    /// 16000 sample rate, Slaney scale, reflect padding, 3000-frame cap.
    pub fn whisper() -> Self {
        Self {
            n_mels: 80,
            n_fft: 400,
            hop_length: 160,
            sample_rate: 16000,
            mel_scale: MelScale::Slaney,
            f_min: 0.0,
            f_max: 8000.0,
            padding: PaddingMode::Reflect,
            max_frames: Some(3000),
            normalize: true,
        }
    }

    /// Large Whisper preset: identical to [`MelConfig::whisper`] except for
    /// 128 mel bins.
    pub fn whisper_large() -> Self {
        Self {
            n_mels: 128,
            ..Self::whisper()
        }
    }

    /// HTK-scale preset with zero padding and no frame cap.
    ///
    /// `f_max` is left at `0.0`, so [`MelConfig::effective_f_max`] resolves
    /// it to half the sample rate.
    pub fn htk_default() -> Self {
        Self {
            n_mels: 80,
            n_fft: 400,
            hop_length: 160,
            sample_rate: 16000,
            mel_scale: MelScale::Htk,
            f_min: 0.0,
            f_max: 0.0,
            padding: PaddingMode::Zero,
            max_frames: None,
            normalize: true,
        }
    }

    /// Upper frequency bound actually in effect.
    ///
    /// Returns `f_max` when it is positive; otherwise (`f_max <= 0.0`) falls
    /// back to `sample_rate / 2`.
    pub fn effective_f_max(&self) -> f64 {
        if self.f_max <= 0.0 {
            // `u32 -> f64` is lossless, so use `From` rather than an `as` cast.
            f64::from(self.sample_rate) / 2.0
        } else {
            self.f_max
        }
    }

    /// Padding length implied by the `padding` mode.
    ///
    /// `Reflect` and `Zero` yield `ceil(n_fft / 2)`; `None` yields `0`.
    /// An `n_fft` of `0` yields `0` instead of underflowing.
    /// (Presumably the pad is applied on each side of the signal; confirm
    /// against the caller.)
    pub fn pad_size(&self) -> usize {
        match self.padding {
            // Same value as `(n_fft - 1) / 2 + 1` for every `n_fft >= 1`, but
            // with no subtraction, so `n_fft == 0` cannot underflow.
            PaddingMode::Reflect | PaddingMode::Zero => self.n_fft / 2 + self.n_fft % 2,
            PaddingMode::None => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `"whisper"` resolves to the 80-bin Slaney preset with a 3000-frame cap.
    #[test]
    fn test_preset_whisper() {
        let preset = MelConfig::from_preset("whisper").unwrap();
        assert_eq!(
            (preset.n_mels, preset.mel_scale, preset.max_frames),
            (80, MelScale::Slaney, Some(3000)),
        );
    }

    /// `"whisper-large"` resolves to the 128-bin Slaney preset.
    #[test]
    fn test_preset_whisper_large() {
        let preset = MelConfig::from_preset("whisper-large").unwrap();
        assert_eq!((preset.n_mels, preset.mel_scale), (128, MelScale::Slaney));
    }

    /// Names outside the preset table yield `None`.
    #[test]
    fn test_preset_unknown() {
        assert!(matches!(MelConfig::from_preset("unknown"), None));
    }

    /// A zero `f_max` falls back to half the sample rate; a positive one is
    /// returned unchanged.
    #[test]
    fn test_effective_f_max() {
        let with_f_max = |f_max: f64| MelConfig {
            sample_rate: 16000,
            f_max,
            ..Default::default()
        };
        assert_eq!(with_f_max(0.0).effective_f_max(), 8000.0);
        assert_eq!(with_f_max(4000.0).effective_f_max(), 4000.0);
    }

    /// A config survives a JSON encode/decode cycle with `n_mels` and
    /// `mel_scale` intact.
    #[test]
    fn test_serde_roundtrip() {
        let original = MelConfig::whisper();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: MelConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(
            (decoded.n_mels, decoded.mel_scale),
            (original.n_mels, original.mel_scale),
        );
    }
}