use crate::{
array::Array,
error::{Error, InvariantViolationPayload, Result},
lm::cache::KvCache,
};
pub trait Model: crate::lm::model::Model {
fn encode_audio(&self, mel: &Array) -> Result<Array>;
fn decode_step(
&self,
token: u32,
encoder_states: &Array,
cache: &mut [Box<dyn KvCache>],
) -> Result<Array> {
let _ = (token, encoder_states, cache);
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"SttModel::decode_step",
"must be overridden by per-model implementation",
)))
}
fn mel_config(&self) -> MelConfig {
MelConfig::whisper_default()
}
fn bos_token(&self) -> u32;
fn eos_token(&self) -> u32;
}
/// Parameters describing how raw audio is turned into a log-mel spectrogram.
///
/// Fields are private; read them through the getters and derive variants with
/// the `with_*` builders on the `impl` block. `MelConfig::whisper_default()`
/// yields the Whisper front-end values (`n_fft` 400, hop 160, 80 mels, 16 kHz).
#[derive(Debug, Clone, Copy)]
pub struct MelConfig {
/// FFT size (presumably in samples).
n_fft: usize,
/// Hop between successive frames (presumably in samples).
hop_length: usize,
/// Analysis window length; `None` presumably means "use `n_fft`" — confirm in `audio::dsp`.
win_length: Option<usize>,
/// Number of mel filterbank bins.
n_mels: usize,
/// Expected input sample rate (Whisper default 16_000, i.e. Hz).
sample_rate: u32,
/// Lower frequency bound of the mel filterbank (presumably Hz).
f_min: f32,
/// Upper frequency bound; `None` presumably means Nyquist — confirm in `audio::dsp`.
f_max: Option<f32>,
/// Flooring strategy applied after the log; semantics live in `audio::dsp::LogFloor`.
log_floor: crate::audio::dsp::LogFloor,
}
impl MelConfig {
    /// Builds a configuration with every parameter given explicitly.
    ///
    /// No validation is performed; values are stored exactly as passed.
    /// When only a few fields differ from Whisper, prefer
    /// [`MelConfig::whisper_default`] followed by the `with_*` builders.
    ///
    /// * `n_fft` – FFT size.
    /// * `hop_length` – hop between frames.
    /// * `win_length` – analysis window length, or `None` for the DSP default.
    /// * `n_mels` – number of mel bins.
    /// * `sample_rate` – expected input sample rate.
    /// * `f_min` / `f_max` – filterbank frequency bounds (`f_max = None` for the DSP default).
    /// * `log_floor` – flooring strategy applied after the log.
    // Each argument is an independent front-end knob; bundling them into a
    // helper struct would just relocate this list, so the lint is silenced.
    #[allow(clippy::too_many_arguments)]
    #[must_use]
    pub const fn new(
        n_fft: usize,
        hop_length: usize,
        win_length: Option<usize>,
        n_mels: usize,
        sample_rate: u32,
        f_min: f32,
        f_max: Option<f32>,
        log_floor: crate::audio::dsp::LogFloor,
    ) -> Self {
        Self {
            n_fft,
            hop_length,
            win_length,
            n_mels,
            sample_rate,
            f_min,
            f_max,
            log_floor,
        }
    }

    /// Whisper's front-end settings.
    ///
    /// These are `n_fft = 400` and `hop_length = 160` at 16 kHz (a 25 ms
    /// window with a 10 ms hop), 80 mel bins, `f_min = 0.0`, default window
    /// length and `f_max`, and `LogFloor::Whisper`.
    #[must_use]
    pub const fn whisper_default() -> Self {
        Self {
            n_fft: 400,
            hop_length: 160,
            win_length: None,
            n_mels: 80,
            sample_rate: 16_000,
            f_min: 0.0,
            f_max: None,
            log_floor: crate::audio::dsp::LogFloor::Whisper,
        }
    }

    /// FFT size.
    #[inline]
    pub const fn n_fft(&self) -> usize {
        self.n_fft
    }

    /// Hop between successive frames.
    #[inline]
    pub const fn hop_length(&self) -> usize {
        self.hop_length
    }

    /// Analysis window length; `None` defers to the DSP default.
    #[inline]
    pub const fn win_length(&self) -> Option<usize> {
        self.win_length
    }

    /// Number of mel bins.
    #[inline]
    pub const fn n_mels(&self) -> usize {
        self.n_mels
    }

    /// Expected input sample rate.
    #[inline]
    pub const fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Lower filterbank frequency bound.
    #[inline]
    pub const fn f_min(&self) -> f32 {
        self.f_min
    }

    /// Upper filterbank frequency bound; `None` defers to the DSP default.
    #[inline]
    pub const fn f_max(&self) -> Option<f32> {
        self.f_max
    }

    /// Flooring strategy applied after the log.
    #[inline]
    pub const fn log_floor(&self) -> crate::audio::dsp::LogFloor {
        self.log_floor
    }

    // The builders below take `self` by value, but `MelConfig` is `Copy`:
    // a call whose result is dropped (`cfg.with_n_mels(128);`) compiles and
    // leaves `cfg` untouched. `#[must_use]` turns that mistake into a warning.

    /// Returns a copy with `n_fft` replaced.
    #[inline]
    #[must_use]
    pub const fn with_n_fft(self, n_fft: usize) -> Self {
        Self { n_fft, ..self }
    }

    /// Returns a copy with `hop_length` replaced.
    #[inline]
    #[must_use]
    pub const fn with_hop_length(self, hop_length: usize) -> Self {
        Self { hop_length, ..self }
    }

    /// Returns a copy with `win_length` replaced.
    #[inline]
    #[must_use]
    pub const fn with_win_length(self, win_length: Option<usize>) -> Self {
        Self { win_length, ..self }
    }

    /// Returns a copy with `n_mels` replaced.
    #[inline]
    #[must_use]
    pub const fn with_n_mels(self, n_mels: usize) -> Self {
        Self { n_mels, ..self }
    }

    /// Returns a copy with `sample_rate` replaced.
    #[inline]
    #[must_use]
    pub const fn with_sample_rate(self, sample_rate: u32) -> Self {
        Self {
            sample_rate,
            ..self
        }
    }

    /// Returns a copy with `f_min` replaced.
    #[inline]
    #[must_use]
    pub const fn with_f_min(self, f_min: f32) -> Self {
        Self { f_min, ..self }
    }

    /// Returns a copy with `f_max` replaced.
    #[inline]
    #[must_use]
    pub const fn with_f_max(self, f_max: Option<f32>) -> Self {
        Self { f_max, ..self }
    }

    /// Returns a copy with `log_floor` replaced.
    #[inline]
    #[must_use]
    pub const fn with_log_floor(self, log_floor: crate::audio::dsp::LogFloor) -> Self {
        Self { log_floor, ..self }
    }
}