use crate::decoder::{TimedToken, TranscriptionResult};
use crate::error::{Error, Result};
use crate::execution::ModelConfig as ExecutionConfig;
use crate::model_multitalker::{MultitalkerEncoderCache, MultitalkerModel};
use crate::nemotron::SentencePieceVocab;
use crate::sortformer::{Sortformer, NUM_SPEAKERS};
use crate::timestamps::{self, TimestampMode};
use crate::transcriber::Transcriber;
use ndarray::{s, Array2, Array3};
use std::path::Path;
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 512;
const WIN_LENGTH: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;
const PREEMPH: f32 = 0.97;
const LOG_ZERO_GUARD: f32 = 5.960_464_5e-8;
const NUM_ENCODER_LAYERS: usize = 24;
const HIDDEN_DIM: usize = 1024;
const LEFT_CONTEXT: usize = 70;
const CONV_CONTEXT: usize = 8;
const VOCAB_SIZE: usize = 1024;
const BLANK_ID: usize = 1024;
const DECODER_LSTM_DIM: usize = 640;
const MAX_SYMBOLS_PER_STEP: usize = 10;
const PRE_ENCODE_CACHE: usize = 9;
const SECONDS_PER_ENCODED_FRAME: f32 = 0.08;
const SPEAKER_ACTIVITY_THRESHOLD: f32 = 0.3;
#[derive(Debug, Clone)]
pub struct WordTimestamp {
pub word: String,
pub start_secs: f32,
pub end_secs: f32,
}
struct SpeakerInstance {
encoder_cache: MultitalkerEncoderCache,
state_1: Array3<f32>,
state_2: Array3<f32>,
last_token: i32,
accumulated_tokens: Vec<(usize, usize)>,
speaker_id: usize,
}
impl SpeakerInstance {
fn new(speaker_id: usize) -> Self {
Self {
encoder_cache: MultitalkerEncoderCache::new(
NUM_ENCODER_LAYERS,
LEFT_CONTEXT,
HIDDEN_DIM,
CONV_CONTEXT,
),
state_1: Array3::zeros((2, 1, DECODER_LSTM_DIM)),
state_2: Array3::zeros((2, 1, DECODER_LSTM_DIM)),
last_token: BLANK_ID as i32,
accumulated_tokens: Vec::new(),
speaker_id,
}
}
}
#[derive(Debug, Clone)]
pub struct SpeakerTranscript {
pub speaker_id: usize,
pub text: String,
pub words: Vec<WordTimestamp>,
}
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum LatencyMode {
#[default]
Normal,
Low,
VeryLow,
Ultra,
}
impl LatencyMode {
pub const fn chunk_mel_frames(self) -> usize {
match self {
Self::Normal => 112, Self::Low => 56, Self::VeryLow => 16, Self::Ultra => 8, }
}
pub const fn encoded_frames(self) -> usize {
match self {
Self::Normal => 14,
Self::Low => 7,
Self::VeryLow => 2,
Self::Ultra => 1,
}
}
pub const fn latency_secs(self) -> f32 {
match self {
Self::Normal => 1.12,
Self::Low => 0.56,
Self::VeryLow => 0.16,
Self::Ultra => 0.08,
}
}
}
#[derive(Debug, Clone)]
pub struct MultitalkerConfig {
pub max_speakers: usize,
pub activity_threshold: f32,
pub latency_mode: LatencyMode,
}
impl Default for MultitalkerConfig {
fn default() -> Self {
Self {
max_speakers: NUM_SPEAKERS,
activity_threshold: SPEAKER_ACTIVITY_THRESHOLD,
latency_mode: LatencyMode::default(),
}
}
}
impl MultitalkerConfig {
pub fn chunk_size(&self) -> usize {
self.latency_mode.chunk_mel_frames()
}
}
pub struct MultitalkerASR {
model: MultitalkerModel,
sortformer: Sortformer,
vocab: SentencePieceVocab,
speakers: Vec<SpeakerInstance>,
config: MultitalkerConfig,
mel_basis: Array2<f32>,
audio_buffer: Vec<f32>,
audio_processed: usize,
chunk_idx: usize,
}
impl MultitalkerASR {
pub fn from_pretrained<P: AsRef<Path>, Q: AsRef<Path>>(
asr_model_dir: P,
sortformer_model_path: Q,
exec_config: Option<ExecutionConfig>,
) -> Result<Self> {
let asr_dir = asr_model_dir.as_ref();
let exec = exec_config.unwrap_or_default();
let vocab = SentencePieceVocab::from_file(asr_dir.join("tokenizer.model"))?;
let model = MultitalkerModel::from_pretrained(asr_dir, exec.clone())?;
let sortformer = Sortformer::with_config(
sortformer_model_path,
Some(exec),
crate::sortformer::DiarizationConfig::default(),
)?;
let mel_basis = crate::audio::create_mel_filterbank(N_FFT, N_MELS, SAMPLE_RATE);
Ok(Self {
model,
sortformer,
vocab,
speakers: Vec::new(),
config: MultitalkerConfig::default(),
mel_basis,
audio_buffer: Vec::new(),
audio_processed: 0,
chunk_idx: 0,
})
}
pub fn reset(&mut self) {
self.speakers.clear();
self.sortformer.reset_state();
self.audio_buffer.clear();
self.audio_processed = 0;
self.chunk_idx = 0;
}
pub fn multitalker_config(&self) -> &MultitalkerConfig {
&self.config
}
pub fn set_max_speakers(&mut self, max_speakers: usize) {
self.config.max_speakers = max_speakers.clamp(1, NUM_SPEAKERS);
}
pub fn set_activity_threshold(&mut self, threshold: f32) {
self.config.activity_threshold = threshold.clamp(0.0, 1.0);
}
pub fn set_latency_mode(&mut self, mode: LatencyMode) {
if self.config.latency_mode != mode {
self.config.latency_mode = mode;
self.reset();
}
}
pub fn chunk_audio_samples(&self) -> usize {
self.config.chunk_size() * HOP_LENGTH
}
pub fn get_transcripts(&self) -> Vec<SpeakerTranscript> {
self.speakers
.iter()
.map(|spk| {
let valid_ids: Vec<usize> = spk
.accumulated_tokens
.iter()
.filter(|&&(t, _)| t < VOCAB_SIZE)
.map(|&(t, _)| t)
.collect();
let words = self.tokens_to_words(&spk.accumulated_tokens);
SpeakerTranscript {
speaker_id: spk.speaker_id,
text: self.vocab.decode(&valid_ids),
words,
}
})
.collect()
}
pub fn transcribe_chunk(&mut self, audio_chunk: &[f32]) -> Result<Vec<SpeakerTranscript>> {
self.audio_buffer.extend_from_slice(audio_chunk);
let total_audio = self.audio_buffer.len();
if total_audio < WIN_LENGTH {
return Ok(vec![]);
}
let full_mel = self.compute_mel_spectrogram(&self.audio_buffer)?;
let total_mel_frames = full_mel.shape()[1];
let processed_mel_frames = self.audio_processed / HOP_LENGTH;
let chunk_size = self.config.chunk_size();
let available_new_frames = total_mel_frames.saturating_sub(processed_mel_frames);
if available_new_frames < chunk_size {
return Ok(vec![]);
}
let raw_preds = self.sortformer.diarize_chunk_raw(audio_chunk)?;
let diar_preds = &raw_preds.predictions;
let mut active_speakers = Vec::new();
for spk_id in 0..self.config.max_speakers {
if spk_id >= diar_preds.ncols() {
break;
}
let max_activity = (0..diar_preds.nrows())
.map(|t| diar_preds[[t, spk_id]])
.fold(0.0f32, f32::max);
if max_activity > self.config.activity_threshold {
active_speakers.push(spk_id);
}
}
for &spk_id in &active_speakers {
if !self.speakers.iter().any(|s| s.speaker_id == spk_id) {
self.speakers.push(SpeakerInstance::new(spk_id));
}
}
let expected_size = PRE_ENCODE_CACHE + chunk_size;
let is_first_chunk = self.chunk_idx == 0;
let main_start = processed_mel_frames;
let mel_chunk = self.build_mel_chunk(&full_mel, main_start, is_first_chunk, expected_size)?;
let chunk_length = expected_size;
let chunk_frame_offset = self.chunk_idx * self.config.latency_mode.encoded_frames();
let mut results = Vec::new();
for &spk_id in &active_speakers {
let (spk_targets, bg_spk_targets) =
self.derive_speaker_targets(diar_preds, spk_id, chunk_length)?;
let spk_idx = self
.speakers
.iter()
.position(|s| s.speaker_id == spk_id)
.unwrap();
let (encoded, enc_len, new_cache) = self.model.run_encoder(
&mel_chunk,
chunk_length as i64,
&self.speakers[spk_idx].encoder_cache,
&spk_targets,
&bg_spk_targets,
)?;
self.speakers[spk_idx].encoder_cache = new_cache;
let tokens = self.decode_chunk_for_speaker(
spk_idx,
&encoded,
enc_len as usize,
chunk_frame_offset,
)?;
self.speakers[spk_idx].accumulated_tokens.extend(&tokens);
let mut text = String::new();
for &(t, _) in &tokens {
if t < VOCAB_SIZE {
text.push_str(&self.vocab.decode_single(t));
}
}
if !text.is_empty() {
let words = self.tokens_to_words(&tokens);
results.push(SpeakerTranscript {
speaker_id: spk_id,
text,
words,
});
}
}
self.audio_processed += chunk_size * HOP_LENGTH;
self.chunk_idx += 1;
let keep_samples = (PRE_ENCODE_CACHE + chunk_size) * HOP_LENGTH + WIN_LENGTH;
if self.audio_buffer.len() > keep_samples * 2 {
let remove = self.audio_buffer.len() - keep_samples;
let actual_remove = remove.min(self.audio_processed);
self.audio_buffer.drain(0..actual_remove);
self.audio_processed -= actual_remove;
}
Ok(results)
}
pub fn transcribe_file_multitalker<P: AsRef<Path>>(
&mut self,
audio_path: P,
) -> Result<Vec<SpeakerTranscript>> {
let (audio, spec) = crate::audio::load_audio(audio_path)?;
if spec.sample_rate != SAMPLE_RATE as u32 {
return Err(Error::Audio(format!(
"Expected {} Hz, got {} Hz",
SAMPLE_RATE, spec.sample_rate
)));
}
let audio = if spec.channels > 1 {
audio
.chunks(spec.channels as usize)
.map(|c| c.iter().sum::<f32>() / spec.channels as f32)
.collect()
} else {
audio
};
self.transcribe_audio_multitalker(&audio)
}
pub fn transcribe_audio_multitalker(
&mut self,
audio: &[f32],
) -> Result<Vec<SpeakerTranscript>> {
self.reset();
let audio_chunk_size = self.chunk_audio_samples();
for chunk in audio.chunks(audio_chunk_size) {
let chunk_vec = if chunk.len() < audio_chunk_size {
let mut p = chunk.to_vec();
p.resize(audio_chunk_size, 0.0);
p
} else {
chunk.to_vec()
};
self.transcribe_chunk(&chunk_vec)?;
}
let flush_chunk = vec![0.0f32; audio_chunk_size];
for _ in 0..3 {
self.transcribe_chunk(&flush_chunk)?;
}
Ok(self.get_transcripts())
}
fn derive_speaker_targets(
&self,
diar_preds: &Array2<f32>,
speaker_id: usize,
encoder_time: usize,
) -> Result<(Array2<f32>, Array2<f32>)> {
let diar_frames = diar_preds.nrows();
let mut spk_vals = Vec::with_capacity(encoder_time);
let mut bg_vals = Vec::with_capacity(encoder_time);
for enc_t in 0..encoder_time {
let diar_t = if diar_frames > 0 && encoder_time > 0 {
(enc_t * diar_frames / encoder_time).min(diar_frames - 1)
} else {
0
};
if diar_t < diar_frames && speaker_id < diar_preds.ncols() {
let spk_val = diar_preds[[diar_t, speaker_id]];
let bg_val = (0..diar_preds.ncols())
.filter(|&j| j != speaker_id)
.map(|j| diar_preds[[diar_t, j]])
.fold(0.0f32, f32::max);
spk_vals.push(spk_val);
bg_vals.push(bg_val);
} else {
spk_vals.push(1.0);
bg_vals.push(0.0);
}
}
let spk_targets = Array2::from_shape_vec((1, encoder_time), spk_vals)
.map_err(|e| Error::Model(format!("spk_targets shape mismatch: {e}")))?;
let bg_spk_targets = Array2::from_shape_vec((1, encoder_time), bg_vals)
.map_err(|e| Error::Model(format!("bg_spk_targets shape mismatch: {e}")))?;
Ok((spk_targets, bg_spk_targets))
}
fn build_mel_chunk(
&self,
full_mel: &Array2<f32>,
main_start: usize,
is_first_chunk: bool,
expected_size: usize,
) -> Result<Array3<f32>> {
let total_mel_frames = full_mel.shape()[1];
let chunk_size = self.config.chunk_size();
let mut chunk_data = vec![0.0f32; N_MELS * expected_size];
if is_first_chunk {
for f in 0..chunk_size.min(total_mel_frames) {
for m in 0..N_MELS {
chunk_data[m * expected_size + PRE_ENCODE_CACHE + f] = full_mel[[m, f]];
}
}
} else {
let cache_start = main_start.saturating_sub(PRE_ENCODE_CACHE);
let cache_frames = main_start - cache_start;
let cache_offset = PRE_ENCODE_CACHE - cache_frames;
for f in 0..cache_frames {
for m in 0..N_MELS {
chunk_data[m * expected_size + cache_offset + f] =
full_mel[[m, cache_start + f]];
}
}
for f in 0..chunk_size.min(total_mel_frames - main_start) {
for m in 0..N_MELS {
chunk_data[m * expected_size + PRE_ENCODE_CACHE + f] =
full_mel[[m, main_start + f]];
}
}
}
Array3::from_shape_vec((1, N_MELS, expected_size), chunk_data)
.map_err(|e| Error::Model(format!("Failed to create mel chunk: {e}")))
}
fn decode_chunk_for_speaker(
&mut self,
spk_idx: usize,
encoder_out: &Array3<f32>,
enc_frames: usize,
chunk_frame_offset: usize,
) -> Result<Vec<(usize, usize)>> {
let mut tokens = Vec::new();
let hidden_dim = encoder_out.shape()[1];
for t in 0..enc_frames {
let frame = encoder_out.slice(s![0, .., t]).to_owned();
let frame = frame
.to_shape((1, 1, hidden_dim))
.map_err(|e| Error::Model(format!("Failed to reshape frame: {e}")))?
.to_owned();
let absolute_frame = chunk_frame_offset + t;
for _ in 0..MAX_SYMBOLS_PER_STEP {
let (logits, new_state_1, new_state_2) = self.model.run_decoder(
&frame,
self.speakers[spk_idx].last_token,
&self.speakers[spk_idx].state_1,
&self.speakers[spk_idx].state_2,
)?;
let mut max_idx = 0;
let mut max_val = f32::NEG_INFINITY;
for (i, &v) in logits.iter().enumerate() {
if v > max_val {
max_val = v;
max_idx = i;
}
}
if max_idx == BLANK_ID {
break;
}
tokens.push((max_idx, absolute_frame));
self.speakers[spk_idx].last_token = max_idx as i32;
self.speakers[spk_idx].state_1 = new_state_1;
self.speakers[spk_idx].state_2 = new_state_2;
}
}
Ok(tokens)
}
fn tokens_to_words(&self, tokens: &[(usize, usize)]) -> Vec<WordTimestamp> {
let timed: Vec<TimedToken> = tokens
.iter()
.filter(|(id, _)| *id < VOCAB_SIZE)
.map(|&(id, frame)| TimedToken {
text: self.vocab.decode_single(id),
start: frame as f32 * SECONDS_PER_ENCODED_FRAME,
end: (frame + 1) as f32 * SECONDS_PER_ENCODED_FRAME,
})
.collect();
timestamps::group_by_words(&timed)
.into_iter()
.map(|t| WordTimestamp {
word: t.text,
start_secs: t.start,
end_secs: t.end,
})
.collect()
}
fn compute_mel_spectrogram(&self, audio: &[f32]) -> Result<Array2<f32>> {
if audio.is_empty() {
return Ok(Array2::zeros((N_MELS, 0)));
}
let preemph = crate::audio::apply_preemphasis(audio, PREEMPH);
let spec = crate::audio::stft(&preemph, N_FFT, HOP_LENGTH, WIN_LENGTH)?;
let mel = self.mel_basis.dot(&spec);
Ok(mel.mapv(|x| (x.max(0.0) + LOG_ZERO_GUARD).ln()))
}
}
impl Transcriber for MultitalkerASR {
fn transcribe_samples(
&mut self,
audio: Vec<f32>,
sample_rate: u32,
channels: u16,
_mode: Option<TimestampMode>,
) -> Result<TranscriptionResult> {
if sample_rate != SAMPLE_RATE as u32 {
return Err(Error::Audio(format!(
"Expected {} Hz, got {} Hz",
SAMPLE_RATE, sample_rate
)));
}
let audio = if channels > 1 {
audio
.chunks(channels as usize)
.map(|c| c.iter().sum::<f32>() / channels as f32)
.collect()
} else {
audio
};
self.reset();
let mel = self.compute_mel_spectrogram(&audio)?;
let total_frames = mel.shape()[1];
if total_frames == 0 {
return Ok(TranscriptionResult {
text: String::new(),
tokens: Vec::new(),
});
}
self.speakers.push(SpeakerInstance::new(0));
let chunk_size = self.config.chunk_size();
let mut buffer_idx = 0;
let mut chunk_idx = 0;
while buffer_idx < total_frames {
let expected_size = PRE_ENCODE_CACHE + chunk_size;
let is_first = chunk_idx == 0;
let mel_chunk = self.build_mel_chunk(&mel, buffer_idx, is_first, expected_size)?;
let chunk_length = expected_size;
let spk_targets = Array2::from_elem((1, chunk_length), 1.0f32);
let bg_spk_targets = Array2::from_elem((1, chunk_length), 0.0f32);
let (encoded, enc_len, new_cache) = self.model.run_encoder(
&mel_chunk,
chunk_length as i64,
&self.speakers[0].encoder_cache,
&spk_targets,
&bg_spk_targets,
)?;
self.speakers[0].encoder_cache = new_cache;
let chunk_frame_offset =
chunk_idx * self.config.latency_mode.encoded_frames();
let tokens =
self.decode_chunk_for_speaker(0, &encoded, enc_len as usize, chunk_frame_offset)?;
self.speakers[0].accumulated_tokens.extend(tokens);
buffer_idx += chunk_size;
chunk_idx += 1;
}
let valid_ids: Vec<usize> = self.speakers[0]
.accumulated_tokens
.iter()
.filter(|&&(t, _)| t < VOCAB_SIZE)
.map(|&(t, _)| t)
.collect();
let text = self.vocab.decode(&valid_ids);
Ok(TranscriptionResult {
text,
tokens: Vec::new(),
})
}
}