use bon::bon;
use snafu::{ResultExt, Snafu};
use svod_arch::ctc::CtcDecoder;
use svod_arch::rnnt::{RnntDecoder, RnntOpts};
use svod_tensor::PrepareConfig;
pub use svod_arch::rnnt::Word;
use crate::audio::{AudioChunk, EncoderBounds, MelConfig, MelSpectrogram, Splitter};
use crate::gigaam::SubsamplingMode;
use crate::gigaam::ctc::CtcHeadJit;
use crate::gigaam::jit::GigaAmEncoderJit;
use crate::gigaam::model::{GigaAm, Head};
use crate::gigaam::rnnt::RnntStepBackend;
use crate::jit::InputSpec;
/// Options controlling how a [`Transcriber`] decodes audio.
///
/// Both [`Default`] and [`TranscribeOpts::from_env`] fill every field from environment
/// variables (`SVOD_TIMESTAMPS`, `SVOD_BEAM_DECODE`, `SVOD_MAX_SCORES_MIB`), so "default"
/// options depend on the process environment.
#[derive(Clone, Debug)]
pub struct TranscribeOpts {
/// Produce per-word timestamps in [`ChunkResult::words`].
pub word_timestamps: bool,
/// Replace a greedy CTC decoder with a beam-search decoder over the same vocabulary.
/// Has no effect on RNN-T heads or on a configured CTC decoder that is not greedy.
pub beam_decode: bool,
/// Memory budget in MiB used to cap the encoder batch size: the batch is limited so that one
/// `n_heads x T' x T'` score buffer per item fits in this budget (and by the model's
/// `max_batch_size`).
pub max_scores_mib: usize,
}
impl Default for TranscribeOpts {
    /// Same as [`TranscribeOpts::from_env`]: every field is taken from its environment
    /// variable, or from the builder's built-in fallback when the variable is unset.
    fn default() -> Self {
        Self::from_env()
    }
}
#[bon]
impl TranscribeOpts {
/// `bon`-generated builder for [`TranscribeOpts`]. Any field not set explicitly falls back to
/// the environment:
///
/// - `word_timestamps`: `true` only if `SVOD_TIMESTAMPS` is exactly `"1"`.
/// - `beam_decode`: `true` only if `SVOD_BEAM_DECODE` is exactly `"1"`.
/// - `max_scores_mib`: `SVOD_MAX_SCORES_MIB` parsed as `usize`, or `256` when the variable is
///   unset or does not parse.
#[builder]
pub fn builder(
#[builder(default = std::env::var("SVOD_TIMESTAMPS").as_deref() == Ok("1"))] word_timestamps: bool,
#[builder(default = std::env::var("SVOD_BEAM_DECODE").as_deref() == Ok("1"))] beam_decode: bool,
#[builder(default = std::env::var("SVOD_MAX_SCORES_MIB").ok().and_then(|s| s.parse().ok()).unwrap_or(256))]
max_scores_mib: usize,
) -> Self {
Self { word_timestamps, beam_decode, max_scores_mib }
}
/// Options with every field taken from the environment (or its fallback); see `builder`.
pub fn from_env() -> Self {
Self::builder().build()
}
}
/// Output of a transcription run.
#[derive(Clone, Debug)]
pub struct TranscribeResult {
/// Texts of all chunks that decoded to non-empty text, joined with single spaces.
pub text: String,
/// Per-chunk results in processing order; chunks that produced no mel frames are absent.
pub chunks: Vec<ChunkResult>,
}
impl TranscribeResult {
    /// Iterates over every word timestamp of every chunk, converted from chunk-relative to
    /// absolute seconds by adding that chunk's `start_sec`.
    ///
    /// Chunks whose `words` is `None` (word timestamps were not requested) contribute nothing.
    pub fn words(&self) -> impl Iterator<Item = Word> + '_ {
        self.chunks
            .iter()
            .filter_map(|chunk| chunk.words.as_deref().map(|list| (chunk.start_sec, list)))
            .flat_map(|(offset, list)| {
                list.iter().map(move |word| Word {
                    text: word.text.clone(),
                    start: word.start + offset,
                    end: word.end + offset,
                })
            })
    }
}
/// Transcription of a single audio chunk.
#[derive(Clone, Debug)]
pub struct ChunkResult {
/// Chunk start in seconds (`start_sample / sample_rate`).
pub start_sec: f32,
/// Chunk end in seconds (`end_sample / sample_rate`).
pub end_sec: f32,
/// Decoded text of this chunk (may be empty).
pub text: String,
/// Word timestamps relative to `start_sec`; `Some` only when word timestamps were requested.
/// Use [`TranscribeResult::words`] for absolute times.
pub words: Option<Vec<Word>>,
}
/// Head-specific decoding state, chosen in `Transcriber::new` from the model's [`Head`].
// NOTE(review): the variants presumably differ a lot in size (hence the allow). Only one value
// is held per `Transcriber`, so boxing would save little — confirm this is the intent.
#[allow(clippy::large_enum_variant)]
pub(crate) enum HeadDecoder {
/// CTC head: prepared JIT program producing logits, plus the CTC decoder (greedy, or beam when
/// `TranscribeOpts::beam_decode` upgraded a greedy one).
Ctc { jit: CtcHeadJit, decoder: CtcDecoder },
/// RNN-T head: step backend passed to the RNN-T decoder on every decode. `sentencepiece` means
/// decoded text contains `U+2581` word markers that are replaced by spaces after decoding.
Rnnt { backend: RnntStepBackend, decoder: RnntDecoder, sentencepiece: bool },
}
/// Groups CTC-decoded characters into space-separated words with time spans.
///
/// `frames[i]` is the encoder frame at which the `i`-th character of `text` was emitted;
/// characters without a matching frame (when `text` is longer than `frames`) are ignored.
/// A word spans from `first_frame * frame_shift` to `(last_frame + 1) * frame_shift`.
/// Runs of spaces never produce empty words.
pub(crate) fn ctc_frames_to_words(text: &str, frames: &[usize], frame_shift: f32) -> Vec<Word> {
    let finish = |(word_text, first, last): (String, usize, usize)| Word {
        text: word_text,
        start: first as f32 * frame_shift,
        end: (last + 1) as f32 * frame_shift,
    };

    let mut out = Vec::new();
    // Word under construction: (characters so far, first frame, last frame).
    let mut pending: Option<(String, usize, usize)> = None;

    for (ch, &frame) in text.chars().zip(frames) {
        if ch == ' ' {
            out.extend(pending.take().map(&finish));
            continue;
        }
        if let Some((buf, _, last)) = pending.as_mut() {
            buf.push(ch);
            *last = frame;
        } else {
            pending = Some((ch.to_string(), frame, frame));
        }
    }
    out.extend(pending.map(&finish));
    out
}
/// Copies the first `actual_sub` time steps of one encoder item from a feature-major
/// `(d_model, t_exec_sub)` buffer (row stride `t_exec_sub`) into a new time-major
/// `(actual_sub, d_model)` buffer, the layout handed to the RNN-T decoder.
fn transpose_dt_to_td(src: &[f32], d_model: usize, t_exec_sub: usize, actual_sub: usize) -> Vec<f32> {
    (0..actual_sub)
        .flat_map(move |t| (0..d_model).map(move |d| src[d * t_exec_sub + t]))
        .collect()
}
/// Boxes an RNN-T decoding failure into [`TranscribeError::RnntDecode`]; used with `map_err`
/// because the variant stores the error boxed and has no `from` conversion.
fn rnnt_decode_err<E: std::error::Error + 'static>(
    err: svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>,
) -> TranscribeError<E> {
    let source = Box::new(err);
    TranscribeError::RnntDecode { source }
}
/// Errors produced while building a [`Transcriber`] or transcribing audio.
///
/// `E` is the error type of the [`Splitter`] in use.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum TranscribeError<E: std::error::Error + 'static> {
/// The splitter failed to chunk the waveform.
#[snafu(display("splitter: {source}"))]
Splitter { source: E },
/// Preparing or executing a JIT program, or creating the RNN-T step backend, failed.
#[snafu(display("{source}"))]
Jit {
#[snafu(source(from(crate::jit::JitError, Box::new)))]
source: Box<crate::jit::JitError>,
},
/// The CTC decoder rejected the head output.
#[snafu(display("{source}"))]
CtcDecode { source: svod_arch::ctc::DecodeError },
/// RNN-T decoding failed (boxed; built via `rnnt_decode_err`).
#[snafu(display("{source}"))]
RnntDecode { source: Box<svod_arch::rnnt::RnntDecodeError<crate::jit::JitError>> },
/// A model-level error.
#[snafu(display("{source}"))]
Model {
#[snafu(source(from(crate::gigaam::error::Error, Box::new)))]
source: Box<crate::gigaam::error::Error>,
},
/// A tensor-library error.
#[snafu(display("{source}"))]
Tensor {
#[snafu(source(from(svod_tensor::error::Error, Box::new)))]
source: Box<svod_tensor::error::Error>,
},
/// A device buffer could not be viewed as the expected host array.
#[snafu(display("{source}"))]
Device {
#[snafu(source(from(svod_device::error::Error, Box::new)))]
source: Box<svod_device::error::Error>,
},
/// The input sample rate differs from the model's; no resampling is performed.
#[snafu(display("WAV is {wav_sr} Hz, model expects {model_sr} Hz (resample first)"))]
SampleRateMismatch { wav_sr: u32, model_sr: u32 },
/// A chunk is longer than the capacity the encoder was prepared for.
#[snafu(display("chunk {idx} length {samples} samples exceeds encoder capacity {max_samples} samples"))]
ChunkExceedsCapacity { idx: usize, samples: usize, max_samples: usize },
/// A chunk ends past the end of the waveform.
#[snafu(display("chunk {idx} end {end_sample} exceeds waveform length {waveform_len}"))]
ChunkOutOfRange { idx: usize, end_sample: usize, waveform_len: usize },
}
/// Batched speech-to-text pipeline: splitter -> mel spectrogram -> JIT encoder -> CTC or RNN-T head.
///
/// The encoder (and CTC head) programs are prepared once in [`Transcriber::new`] for fixed
/// upper bounds on batch size and mel length; every later call executes within those bounds.
pub struct Transcriber<S: Splitter> {
/// Loaded model (config, weights, head type).
model: GigaAm,
/// Options the transcriber was built with.
opts: TranscribeOpts,
/// Chunks a waveform in `transcribe`.
splitter: S,
/// Mel front end configured from the model config.
mel: MelSpectrogram,
/// Head-specific JIT program / backend and decoder.
head_decoder: HeadDecoder,
/// Encoder program prepared for `[max_batch, n_mels, max_t_mel]` inputs.
encoder_jit: GigaAmEncoderJit,
/// Largest number of chunks processed per encoder execution.
max_batch: usize,
/// Prepared mel time capacity; also the row stride of the encoder mel input buffer.
max_t_mel: usize,
}
impl<S: Splitter> Transcriber<S> {
pub fn new(model: GigaAm, splitter: S, opts: TranscribeOpts) -> Result<Self, TranscribeError<S::Error>> {
let mel = MelSpectrogram::new(&MelConfig {
sample_rate: model.config.sample_rate,
n_fft: model.config.n_fft,
hop_length: model.config.hop_length,
win_length: model.config.win_length,
n_mels: model.config.n_mels,
center: model.config.mel_center,
});
let subsampling_factor = model.config.subsampling_factor;
let hop_length = model.config.hop_length;
let model_bounds = EncoderBounds {
sample_rate: model.config.sample_rate as u32,
hop_length,
subsampling_factor,
max_mel_frames: model.config.max_mel_frames,
};
let chunk_samples_cap = splitter.max_chunk_samples(&model_bounds).min(model_bounds.max_samples());
let chunk_mel = (chunk_samples_cap / hop_length).saturating_add(2 * subsampling_factor);
let max_t_mel = chunk_mel.max(1).next_power_of_two().min(model.config.max_mel_frames).max(subsampling_factor);
let t_sub_max = (max_t_mel / subsampling_factor).max(1);
let scores_dtype_bytes = model.encoder.input_dtype().bytes();
let bytes_per_batch = model.config.n_heads * t_sub_max * t_sub_max * scores_dtype_bytes;
let target_scores_bytes = opts.max_scores_mib * 1024 * 1024;
let max_batch_by_memory = (target_scores_bytes / bytes_per_batch.max(1)).max(1);
let max_batch = max_batch_by_memory.min(model.config.max_batch_size);
let prepare_config = PrepareConfig::from_env();
let mut encoder_jit = GigaAmEncoderJit::new(model.clone()).with_b_bound(max_batch).with_t_bound(max_t_mel);
encoder_jit
.prepare_with_config(
InputSpec::f32(&[max_batch, model.config.n_mels, max_t_mel]),
InputSpec::i32(&[max_batch]),
&prepare_config,
)
.context(JitSnafu)?;
let head_decoder = match &model.head {
Head::Ctc(_) => {
let decoder = if opts.beam_decode {
match &model.config.decoder {
CtcDecoder::Greedy(g) => CtcDecoder::Beam(Box::new(svod_arch::ctc::BeamDecoder::new(
g.vocabulary().to_vec(),
svod_arch::ctc::BeamOpts::default(),
))),
other => other.clone(),
}
} else {
model.config.decoder.clone()
};
let subs_kernel_size = match model.config.subsampling_mode {
SubsamplingMode::Conv1d => model.config.subs_kernel_size,
SubsamplingMode::Conv2d => 3,
};
let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
let mut jit = CtcHeadJit::new(model.clone()).with_b_bound(max_batch).with_t_sub_bound(max_t_sub);
jit.prepare_with_config(InputSpec::f32(&[max_batch, model.config.d_model, max_t_sub]), &prepare_config)
.context(JitSnafu)?;
HeadDecoder::Ctc { jit, decoder }
}
Head::Rnnt { runtime, .. } => {
let backend = RnntStepBackend::from_model(model.clone()).context(JitSnafu)?;
let decoder = RnntDecoder::new(
runtime.vocabulary.clone(),
RnntOpts { max_symbols_per_step: runtime.max_symbols_per_step },
);
HeadDecoder::Rnnt { backend, decoder, sentencepiece: runtime.sentencepiece }
}
};
Ok(Self { model, opts, splitter, mel, head_decoder, encoder_jit, max_batch, max_t_mel })
}
pub fn encoder_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
self.bounds_with(sample_rate, self.model.config.max_mel_frames)
}
fn prepared_bounds(&self, sample_rate: u32) -> Result<EncoderBounds, TranscribeError<S::Error>> {
self.bounds_with(sample_rate, self.max_t_mel)
}
fn bounds_with(&self, sample_rate: u32, max_mel_frames: usize) -> Result<EncoderBounds, TranscribeError<S::Error>> {
if sample_rate as usize != self.model.config.sample_rate {
return Err(TranscribeError::SampleRateMismatch {
wav_sr: sample_rate,
model_sr: self.model.config.sample_rate as u32,
});
}
Ok(EncoderBounds {
sample_rate,
hop_length: self.model.config.hop_length,
subsampling_factor: self.model.config.subsampling_factor,
max_mel_frames,
})
}
pub fn transcribe(
&mut self,
waveform: &[f32],
sample_rate: u32,
) -> Result<TranscribeResult, TranscribeError<S::Error>> {
let bounds = self.encoder_bounds(sample_rate)?;
let chunks = self.splitter.split(waveform, &bounds).context(SplitterSnafu)?;
self.transcribe_chunks(waveform, sample_rate, &chunks)
}
pub fn transcribe_chunks(
&mut self,
waveform: &[f32],
sample_rate: u32,
chunks: &[AudioChunk],
) -> Result<TranscribeResult, TranscribeError<S::Error>> {
let max_samples = self.prepared_bounds(sample_rate)?.max_samples();
for (idx, chunk) in chunks.iter().enumerate() {
if chunk.end_sample > waveform.len() {
return Err(TranscribeError::ChunkOutOfRange {
idx,
end_sample: chunk.end_sample,
waveform_len: waveform.len(),
});
}
let samples = chunk.end_sample.saturating_sub(chunk.start_sample);
if samples > max_samples {
return Err(TranscribeError::ChunkExceedsCapacity { idx, samples, max_samples });
}
}
let n_mels = self.mel.n_mels();
if chunks.is_empty() {
return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
}
let sample_rate_hz = self.model.config.sample_rate;
let d_model = self.model.config.d_model;
let subs_kernel_size = match self.model.config.subsampling_mode {
SubsamplingMode::Conv1d => self.model.config.subs_kernel_size,
SubsamplingMode::Conv2d => 3,
};
let max_t_mel = self.max_t_mel;
let max_t_sub = subs_output_length(subs_kernel_size, max_t_mel);
let max_batch = self.max_batch;
let want_words = self.opts.word_timestamps;
let chunks_meta: Vec<(usize, usize, usize, f32, f32)> = chunks
.iter()
.filter_map(|c| {
let mel_len = self.mel.num_frames(c.end_sample.saturating_sub(c.start_sample));
if mel_len == 0 {
return None;
}
let start_sec = c.start_sample as f32 / sample_rate_hz as f32;
let end_sec = c.end_sample as f32 / sample_rate_hz as f32;
Some((c.start_sample, c.end_sample, mel_len, start_sec, end_sec))
})
.collect();
if chunks_meta.is_empty() {
return Ok(TranscribeResult { text: String::new(), chunks: Vec::new() });
}
let num_chunks = chunks_meta.len();
let mut chunk_results: Vec<ChunkResult> = Vec::with_capacity(num_chunks);
for chunk_batch_start in (0..num_chunks).step_by(max_batch) {
let b = (num_chunks - chunk_batch_start).min(max_batch);
let mut chunk_lengths = vec![0usize; b];
let batch_mels: Vec<Vec<f32>> = (0..b)
.map(|bi| {
let &(start_sample, end_sample, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
let mut chunk_mel = ndarray::Array3::<f32>::zeros((1, n_mels, valid));
{
let mut view = chunk_mel.view_mut().into_dyn();
self.mel.forward_into(&waveform[start_sample..end_sample], &mut view);
}
chunk_mel.as_slice().expect("contiguous chunk mel").to_vec()
})
.collect();
{
let buf = self.encoder_jit.mel_mut().context(JitSnafu)?;
let mut view = buf.as_array_mut::<f32>().context(DeviceSnafu)?;
let slice = view.as_slice_mut().expect("contiguous mel buffer");
slice.fill(0.0);
for (bi, chunk_len) in chunk_lengths.iter_mut().enumerate() {
let &(_, _, valid, _, _) = &chunks_meta[chunk_batch_start + bi];
*chunk_len = valid;
let chunk_mel = &batch_mels[bi];
for mel_bin in 0..n_mels {
let src = mel_bin * valid;
let dst = ((bi * n_mels) + mel_bin) * max_t_mel;
slice[dst..dst + valid].copy_from_slice(&chunk_mel[src..src + valid]);
}
}
}
{
let buf = self.encoder_jit.lengths_mut().context(JitSnafu)?;
let mut view = buf.as_array_mut::<i32>().context(DeviceSnafu)?;
let slice = view.as_slice_mut().expect("contiguous lengths buffer");
slice.fill(0);
for (i, len) in chunk_lengths.iter().enumerate() {
slice[i] = *len as i32;
}
}
let t_exec = chunk_lengths.iter().copied().max().unwrap_or(1).max(1);
let t_exec_sub = subs_output_length(subs_kernel_size, t_exec);
self.encoder_jit.execute_with_vars(&[("b", b as i64), ("t", t_exec as i64)]).context(JitSnafu)?;
match &mut self.head_decoder {
HeadDecoder::Ctc { jit, decoder } => {
{
let n = b * d_model * t_exec_sub;
let src_flat =
self.encoder_jit.output().context(JitSnafu)?.as_array::<f32>().context(DeviceSnafu)?;
let src_3d = src_flat
.slice(ndarray::s![0..n])
.into_shape_with_order((b, d_model, t_exec_sub))
.expect("encoder output reshape");
let dst_flat =
jit.encoded_mut().context(JitSnafu)?.as_array_mut::<f32>().context(DeviceSnafu)?;
let mut dst_3d = dst_flat
.into_shape_with_order((max_batch, d_model, max_t_sub))
.expect("head input reshape");
dst_3d.slice_mut(ndarray::s![0..b, 0..d_model, 0..t_exec_sub]).assign(&src_3d);
}
jit.execute_with_vars(&[("b", b as i64), ("t_sub", t_exec_sub as i64)]).context(JitSnafu)?;
let total_vocab = decoder.total_vocab();
let item_stride = t_exec_sub * total_vocab;
let logits_buf = jit.output().context(JitSnafu)?;
let logits = logits_buf.as_array::<f32>().context(DeviceSnafu)?;
let flat = logits.as_slice().expect("contiguous head logits");
for (bi, mel_len) in chunk_lengths.iter().enumerate() {
let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);
let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];
let (text, frames) = if want_words {
let (text, frames) = decoder
.decode_with_timestamps(item_slice, t_exec_sub, actual_sub)
.context(CtcDecodeSnafu)?;
(text, Some(frames))
} else {
let text = decoder.decode(item_slice, t_exec_sub, actual_sub).context(CtcDecodeSnafu)?;
(text, None)
};
let words = want_words.then(|| {
let frames = frames.as_deref().unwrap_or(&[]);
ctc_frames_to_words(&text, frames, frame_shift)
});
chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
}
}
HeadDecoder::Rnnt { backend, decoder, sentencepiece } => {
let item_stride = d_model * t_exec_sub;
let enc_buf = self.encoder_jit.output().context(JitSnafu)?;
let enc = enc_buf.as_array::<f32>().context(DeviceSnafu)?;
let flat = enc.as_slice().expect("contiguous encoder output");
for (bi, mel_len) in chunk_lengths.iter().enumerate() {
let actual_sub = subs_output_length(subs_kernel_size, *mel_len);
let &(start_sample, end_sample, _, start_sec, end_sec) = &chunks_meta[chunk_batch_start + bi];
let chunk_duration_sec = (end_sample - start_sample) as f32 / sample_rate_hz as f32;
let frame_shift = chunk_duration_sec / (actual_sub.max(1) as f32);
let item_slice = &flat[bi * item_stride..bi * item_stride + item_stride];
let frames = transpose_dt_to_td(item_slice, d_model, t_exec_sub, actual_sub);
let backend: &mut RnntStepBackend = backend;
let (raw, emissions) = if want_words {
let (s, e) = decoder
.decode_with_timestamps(&frames, actual_sub, actual_sub, d_model, backend)
.map_err(rnnt_decode_err)?;
(s, e)
} else {
let s = decoder
.decode(&frames, actual_sub, actual_sub, d_model, backend)
.map_err(rnnt_decode_err)?;
(s, Vec::new())
};
let words = want_words.then(|| decoder.frames_to_words(&emissions, frame_shift));
let text = if *sentencepiece { raw.replace('\u{2581}', " ").trim().to_string() } else { raw };
chunk_results.push(ChunkResult { start_sec, end_sec, text, words });
}
}
}
}
let text =
chunk_results.iter().map(|c| c.text.as_str()).filter(|s| !s.is_empty()).collect::<Vec<_>>().join(" ");
Ok(TranscribeResult { text, chunks: chunk_results })
}
}
/// Number of frames left after two stride-2 subsampling convolutions with kernel
/// `kernel_size` and padding `(kernel_size - 1) / 2`.
///
/// Each stage applies `floor((len + 2 * pad - kernel_size) / 2) + 1`. When the input is
/// shorter than one kernel window (e.g. `mel_frames == 0`, or a 1-frame input with an even
/// kernel), the stage yields `0` frames instead of underflowing `usize` and panicking.
fn subs_output_length(kernel_size: usize, mel_frames: usize) -> usize {
    let pad = kernel_size.saturating_sub(1) / 2;
    (0..2).fold(mel_frames, |len, _| {
        (len + 2 * pad).checked_sub(kernel_size).map_or(0, |span| span / 2 + 1)
    })
}