use std::path::Path;
use smol_str::format_smolstr;
use crate::{
array::Array,
audio::{dsp, io as audio_io},
error::{
EmptyInputPayload, Error, LengthMismatchPayload, OutOfRangePayload, RankMismatchPayload,
Result, try_extend_from_slice,
},
lm::{
cache::KvCache,
generate::{
FinishReason, GenConfig, GenStep, LogitsProcessor, Sampler, make_logits_processors,
make_sampler,
},
},
ops,
};
/// Default upper bound, in seconds, on the duration of an audio input.
///
/// This is the value `SttGenConfig::default()` uses for `max_audio_seconds`.
pub const DEFAULT_MAX_AUDIO_SECONDS: f32 = 30.0;
/// Configuration for speech-to-text generation: the language-model decoding
/// settings plus the options that control how the audio file is loaded.
pub struct SttGenConfig {
/// Decoding settings. `stt_generate` validates them, builds the sampler and
/// logits processors from them, and reads `max_tokens` and `eos`.
lm: GenConfig,
/// What to do when the file's sample rate differs from the model's
/// mel-config sample rate: `true` resamples (linear), `false` makes audio
/// loading fail with an out-of-range error.
auto_resample: bool,
/// Upper bound on the audio duration, in seconds. Must be finite and `> 0`;
/// this is checked when audio is loaded, not when the config is built.
max_audio_seconds: f32,
}
impl SttGenConfig {
    /// Assembles a configuration from its three parts.
    ///
    /// No validation happens here; `max_audio_seconds` is only checked when
    /// audio is actually loaded.
    pub fn new(lm: GenConfig, auto_resample: bool, max_audio_seconds: f32) -> Self {
        Self {
            max_audio_seconds,
            auto_resample,
            lm,
        }
    }

    /// Shared view of the language-model decoding settings.
    #[inline(always)]
    pub fn lm(&self) -> &GenConfig {
        &self.lm
    }

    /// Mutable view of the language-model decoding settings.
    #[inline(always)]
    pub fn lm_mut(&mut self) -> &mut GenConfig {
        &mut self.lm
    }

    /// Whether audio at a non-matching sample rate is resampled.
    #[inline(always)]
    pub fn auto_resample(&self) -> bool {
        self.auto_resample
    }

    /// The audio duration cap, in seconds.
    #[inline(always)]
    pub fn max_audio_seconds(&self) -> f32 {
        self.max_audio_seconds
    }

    /// Consumes the configuration and yields the decoding settings, dropping
    /// the audio options.
    pub fn into_lm(self) -> GenConfig {
        let Self { lm, .. } = self;
        lm
    }

    /// Returns the configuration with its decoding settings replaced by `lm`.
    pub fn with_lm(mut self, lm: GenConfig) -> Self {
        self.lm = lm;
        self
    }

    /// Returns the configuration with `auto_resample` replaced.
    pub fn with_auto_resample(mut self, auto_resample: bool) -> Self {
        self.auto_resample = auto_resample;
        self
    }

    /// Returns the configuration with `max_audio_seconds` replaced.
    pub fn with_max_audio_seconds(mut self, max_audio_seconds: f32) -> Self {
        self.max_audio_seconds = max_audio_seconds;
        self
    }
}
impl Default for SttGenConfig {
    /// Default decoding settings, resampling enabled, and the
    /// `DEFAULT_MAX_AUDIO_SECONDS` duration cap.
    fn default() -> Self {
        let lm = GenConfig::default();
        let auto_resample = true;
        let max_audio_seconds = DEFAULT_MAX_AUDIO_SECONDS;
        Self {
            lm,
            auto_resample,
            max_audio_seconds,
        }
    }
}
/// Loads the audio file at `audio_path` and converts it into a log-mel
/// spectrogram using the parameters from `model.mel_config()`.
///
/// The steps run in this order:
/// 1. reject a `cfg.max_audio_seconds()` that is non-finite or `<= 0`;
/// 2. load the samples and their source sample rate via
///    `audio_io::load_audio_with_max_seconds`;
/// 3. re-check the loaded duration against the cap, before any resampling or
///    spectrogram allocation;
/// 4. bring the samples to the mel config's sample rate, resampling linearly
///    only when `cfg.auto_resample()` is set;
/// 5. reject an empty sample buffer, and one whose length does not fit `i32`;
/// 6. wrap the samples in a 1-D `Array` and compute the spectrogram.
///
/// # Errors
/// - `Error::OutOfRange` for an invalid cap, audio longer than the cap, a
///   sample-rate mismatch with `auto_resample` disabled, or a sample count
///   above `i32::MAX`.
/// - `Error::EmptyInput` when no samples remain after load/resample.
/// - Any error propagated from the loader, the resampler, `Array::from_slice`
///   or `dsp::log_mel_spectrogram_with`.
fn audio_path_to_mel<M: super::model::Model>(
    model: &M,
    audio_path: &Path,
    cfg: &SttGenConfig,
) -> Result<Array> {
    let cap = cfg.max_audio_seconds();
    if !cap.is_finite() || cap <= 0.0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "stt_generate: max_audio_seconds",
            "must be a finite value > 0",
            // Same payload text as before, but built like every other payload
            // in this file (no intermediate `String`).
            format_smolstr!("{}", cap),
        )));
    }

    let mc = model.mel_config();
    let (samples, src_sr) = audio_io::load_audio_with_max_seconds(audio_path, cap)?;

    // Duration is measured at the source rate, in f64, before anything larger
    // is allocated from these samples.
    let src_duration = samples.len() as f64 / f64::from(src_sr);
    if src_duration > f64::from(cap) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "stt_generate: audio duration (rejected before resample / mel-spec allocation)",
            "must be <= `max_audio_seconds` cap",
            format_smolstr!(
                "duration={src_duration:.3}s, src_sample_rate={src_sr}, samples={}, cap={:.3}s",
                samples.len(),
                cap
            ),
        )));
    }

    let target_sr = mc.sample_rate();
    let samples: Vec<f32> = if src_sr == target_sr {
        samples
    } else if cfg.auto_resample() {
        audio_io::resample_linear(&samples, src_sr, target_sr)?
    } else {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "stt_generate: audio sample rate (auto_resample=false)",
            "must equal model.mel_config().sample_rate (or enable auto_resample / pre-resample)",
            format_smolstr!("src_sr={src_sr}, target_sr={target_sr}"),
        )));
    };

    if samples.is_empty() {
        return Err(Error::EmptyInput(EmptyInputPayload::new(
            "stt_generate: audio input (0 samples after load/resample; \
`model.encode_audio` requires at least one mel frame — provide a non-empty WAV)",
        )));
    }

    // The array shape is expressed in `i32`, so the length must fit.
    let n_samples = i32::try_from(samples.len()).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "stt_generate: samples.len()",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{}", samples.len()),
        ))
    })?;
    let samples_arr = Array::from_slice::<f32>(&samples, &[n_samples])?;

    dsp::log_mel_spectrogram_with(
        &samples_arr,
        mc.n_fft(),
        mc.hop_length(),
        mc.win_length(),
        mc.n_mels(),
        mc.sample_rate(),
        mc.f_min(),
        mc.f_max(),
        mc.log_floor(),
    )
}
/// Iterator state for token-by-token speech-to-text decoding.
///
/// Created by `stt_generate`; each call to `next` runs one decoder step and
/// yields a `Result<GenStep>`.
pub struct SttGenerator<'a, M> {
/// Model whose `decode_step` is called on every step.
model: &'a M,
/// Output of `model.encode_audio`, passed unchanged to every decoder step.
encoder_states: Array,
/// Key/value caches handed mutably to every decoder step.
cache: Vec<Box<dyn KvCache>>,
/// Picks the next token from the normalized log-probabilities.
sampler: Sampler,
/// Logits processors applied in order before normalization.
processors: Vec<LogitsProcessor>,
/// Tokens fed to the decoder so far; only appended to when `processors` is
/// non-empty, since it exists solely as their input.
history: Vec<u32>,
/// Token fed to the next decoder step; starts as `model.bos_token()` and is
/// replaced by each sampled token.
last: u32,
/// Number of steps that completed successfully.
produced: usize,
/// Iteration ends once `produced` reaches this value.
max_tokens: usize,
/// Tokens that end generation: the configured ones plus the model's EOS.
eos: Vec<u32>,
/// Set once the iterator has finished (stop token, token budget, or error).
done: bool,
}
impl<M: super::model::Model> SttGenerator<'_, M> {
    /// Runs a single decoder step and samples one token.
    ///
    /// Feeds `self.last` to the model, checks that the returned logits have
    /// shape `[1, V]` with `V > 0`, applies the logits processors (if any),
    /// normalizes to log-probabilities and samples. The returned `GenStep`
    /// carries the log-probabilities with axis 0 squeezed away and has
    /// `finish_reason` unset; the iterator fills that in.
    ///
    /// This method does not update `last` or `produced`.
    fn step(&mut self) -> Result<GenStep> {
        let mut logits =
            self.model
                .decode_step(self.last, &self.encoder_states, &mut self.cache)?;

        // Shape validation happens before any processing of the logits.
        {
            let dims = logits.shape();
            let rank = dims.len();
            if rank != 2 {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "stt_generate: `decode_step` returned logits must be rank 2 (shape [1, V])",
                    rank as u32,
                    dims,
                )));
            }
            if dims[0] != 1 {
                return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                    "stt_generate: `decode_step` returned logits batch dim (must be 1; only single-utterance decoding is supported)",
                    1,
                    dims[0],
                )));
            }
            if dims[1] == 0 {
                return Err(Error::EmptyInput(EmptyInputPayload::new(
                    "stt_generate: `decode_step` returned logits vocab dim (V == 0)",
                )));
            }
        }

        // The token history is only tracked when a processor will read it.
        if !self.processors.is_empty() {
            try_extend_from_slice(&mut self.history, &[self.last])?;
            logits = self
                .processors
                .iter()
                .try_fold(logits, |acc, processor| processor.apply(&self.history, &acc))?;
        }

        // log-softmax: subtract the log-sum-exp from the logits.
        let normalizer = ops::reduction::logsumexp(&logits, true)?;
        let log_probs = ops::arithmetic::subtract(&logits, &normalizer)?;

        let mut choice = self.sampler.sample(&log_probs)?;
        let token: u32 = choice.item::<u32>()?;

        let log_probs = ops::shape::squeeze_axes(&log_probs, &[0])?;
        Ok(GenStep {
            token,
            logprobs: Some(log_probs),
            step_index: self.produced,
            finish_reason: None,
        })
    }
}
impl<M: super::model::Model> Iterator for SttGenerator<'_, M> {
    type Item = Result<GenStep>;

    /// Yields the next decoding step.
    ///
    /// Returns `None` once the generator has finished or the token budget is
    /// used up. A sampled stop token is still yielded, tagged with
    /// `FinishReason::Eos`, and ends the iteration afterwards. An error is
    /// yielded once and also ends the iteration.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done || self.produced >= self.max_tokens {
            self.done = true;
            return None;
        }

        let mut step = match self.step() {
            Ok(step) => step,
            Err(err) => {
                self.done = true;
                return Some(Err(err));
            }
        };

        self.produced += 1;
        self.last = step.token;
        if self.eos.contains(&step.token) {
            self.done = true;
            step.finish_reason = Some(FinishReason::Eos);
        }
        Some(Ok(step))
    }
}
pub fn stt_generate<'a, M: super::model::Model>(
model: &'a M,
audio_path: &Path,
cache: Vec<Box<dyn KvCache>>,
cfg: SttGenConfig,
) -> Result<SttGenerator<'a, M>> {
cfg.lm().validate()?;
let (sampler, processors) = {
let lm = cfg.lm();
let sampler = make_sampler(
lm.temp,
lm.top_p,
lm.min_p,
lm.min_tokens_to_keep,
lm.top_k,
lm.xtc_probability,
lm.xtc_threshold,
&lm.xtc_special_tokens,
lm.seed,
)?;
let processors = make_logits_processors(
&lm.logit_bias,
lm.repetition_penalty,
lm.repetition_context_size,
lm.presence_penalty,
lm.presence_context_size,
lm.frequency_penalty,
lm.frequency_context_size,
)?;
(sampler, processors)
};
let mel = audio_path_to_mel(model, audio_path, &cfg)?;
let encoder_states = model.encode_audio(&mel)?;
let lm = cfg.into_lm();
let max_tokens = lm.max_tokens;
let cfg_eos = lm.eos;
let model_eos = model.eos_token();
let mut eos: Vec<u32> = cfg_eos;
if !eos.contains(&model_eos) {
eos.push(model_eos);
}
Ok(SttGenerator {
model,
encoder_states,
cache,
sampler,
processors,
history: Vec::new(),
last: model.bos_token(),
produced: 0,
max_tokens,
eos,
done: false,
})
}
/// Loads the audio file at `audio_path`, converts it to a mel spectrogram
/// using the audio options in `cfg`, and returns `model.encode_audio` applied
/// to it.
///
/// # Errors
/// Propagates any failure from audio loading or from `model.encode_audio`.
pub fn encode_audio_file<M: super::model::Model>(
    model: &M,
    audio_path: &Path,
    cfg: &SttGenConfig,
) -> Result<Array> {
    audio_path_to_mel(model, audio_path, cfg).and_then(|mel| model.encode_audio(&mel))
}