mod merge;
mod planner;
mod vad;
use crate::config::TranscriptionConfig;
use crate::constants::SAMPLE_RATE;
use crate::error::TranscriptionError;
use crate::models::ModelBundle;
use crate::pipeline::TranscriptionPipeline;
use crate::types::{TimedToken, TranscriptChunk, TranscriptionResult};
pub use planner::{OverlapChunkConfig, VadConfig, VadSegmentationConfig};
use self::merge::merge_overlapping_windows;
use self::planner::{
SampleRange, detect_speech_regions, plan_region_subsegments, region_probability_slice,
};
use self::vad::SileroVad;
/// Transcription pipeline for audio that may exceed a single model window.
///
/// Input no longer than `config.transcription.max_audio_samples` is handed straight to the
/// wrapped [`TranscriptionPipeline`]. Longer input is first scored by the VAD, split into
/// speech regions, and each region is transcribed separately.
#[derive(Debug, Clone)]
pub struct LongFormTranscriptionPipeline {
    /// Single-window pipeline used for every individual transcription / raw-decode call.
    inner: TranscriptionPipeline,
    /// Voice-activity detector producing the speech probabilities used to find regions.
    vad: SileroVad,
    /// Configuration used by `run`; initialised with `LongFormConfig::default()`.
    default_config: LongFormConfig,
}
/// Tuning knobs for long-form transcription.
#[derive(Debug, Clone, Default)]
pub struct LongFormConfig {
    /// Settings for each single-window transcription call. Its `max_audio_samples` also
    /// decides whether input (or a speech region) needs to be split at all.
    pub transcription: TranscriptionConfig,
    /// Settings for the VAD pass. `default_threshold` is passed to
    /// `segmentation.threshold(..)` (presumably as a fallback threshold; confirm in planner).
    pub vad: VadConfig,
    /// Controls speech-region detection and the pause-based sub-segment planner.
    pub segmentation: VadSegmentationConfig,
    /// Controls the overlapping-window fallback, used when a region is too long and
    /// `plan_region_subsegments` yields no split.
    pub overlap: OverlapChunkConfig,
}
impl LongFormTranscriptionPipeline {
    /// Loads a long-form pipeline from a directory of model files.
    ///
    /// # Errors
    /// Propagates any error from [`Self::from_bundle`] for the bundle rooted at `models_dir`.
    pub fn from_dir(models_dir: impl Into<std::path::PathBuf>) -> Result<Self, TranscriptionError> {
        let bundle = ModelBundle::from_dir(models_dir);
        Self::from_bundle(bundle)
    }

    /// Builds a long-form pipeline from an already-located model bundle.
    ///
    /// The bundle is validated for long-form use before the inner pipeline and the VAD are
    /// constructed; `run` will use `LongFormConfig::default()`.
    ///
    /// # Errors
    /// Returns the first error from bundle validation, inner-pipeline construction, or VAD
    /// construction, in that order.
    pub fn from_bundle(bundle: ModelBundle) -> Result<Self, TranscriptionError> {
        bundle.validate_long_form()?;
        // The inner pipeline takes the bundle by value, but the VAD still needs its path below.
        let inner = TranscriptionPipeline::from_bundle(bundle.clone())?;
        let vad = SileroVad::new(bundle.vad_dir())?;
        Ok(Self {
            inner,
            vad,
            default_config: LongFormConfig::default(),
        })
    }

    /// Downloads the long-form model bundle and builds a pipeline from it.
    ///
    /// # Errors
    /// A download failure is reported as `TranscriptionError::CoreMl` with a
    /// `"model download failed: ..."` message; other errors come from [`Self::from_bundle`].
    // NOTE(review): download failures reuse the CoreMl variant; confirm no better variant exists.
    #[cfg(feature = "online")]
    pub fn from_pretrained() -> Result<Self, TranscriptionError> {
        let bundle = ModelBundle::from_pretrained_long_form().map_err(|error| {
            TranscriptionError::CoreMl(format!("model download failed: {error}"))
        })?;
        Self::from_bundle(bundle)
    }

    /// Transcribes `audio` using the pipeline's default configuration.
    ///
    /// # Errors
    /// See [`Self::run_with_config`].
    pub fn run(&self, audio: &[f32]) -> Result<TranscriptionResult, TranscriptionError> {
        self.run_with_config(audio, &self.default_config)
    }

    /// Transcribes `audio` with an explicit configuration.
    ///
    /// If `audio` fits in one window (`len <= config.transcription.max_audio_samples`), it is
    /// transcribed directly by the inner pipeline. Otherwise the VAD-driven long-form path runs.
    ///
    /// # Errors
    /// Returns `TranscriptionError::EmptyAudio` for empty input, and otherwise propagates
    /// VAD or transcription errors.
    pub fn run_with_config(
        &self,
        audio: &[f32],
        config: &LongFormConfig,
    ) -> Result<TranscriptionResult, TranscriptionError> {
        if audio.is_empty() {
            return Err(TranscriptionError::EmptyAudio);
        }
        if audio.len() <= config.transcription.max_audio_samples {
            return self.inner.run_with_config(audio, &config.transcription);
        }
        self.run_long_form(audio, config)
    }

    /// Transcribes `audio` as a single chunk via the inner pipeline's `run`.
    ///
    /// This skips VAD and does not consult this pipeline's default configuration.
    ///
    /// # Errors
    /// Propagates whatever the inner pipeline's `run` returns.
    pub fn run_chunk(&self, audio: &[f32]) -> Result<TranscriptionResult, TranscriptionError> {
        self.inner.run(audio)
    }

    /// Long-form path: scores `audio` with the VAD, detects speech regions, and transcribes
    /// each region. Regions that produce no tokens contribute neither a chunk nor tokens.
    ///
    /// When no speech region is found, an empty result spanning the whole input is returned.
    fn run_long_form(
        &self,
        audio: &[f32],
        config: &LongFormConfig,
    ) -> Result<TranscriptionResult, TranscriptionError> {
        let probabilities = self.vad.process(audio, &config.vad)?;
        let regions = detect_speech_regions(
            &probabilities,
            audio.len(),
            config.segmentation.threshold(config.vad.default_threshold),
            &config.segmentation,
        );
        if regions.is_empty() {
            return Ok(TranscriptionResult::empty(duration_seconds(audio.len())));
        }
        let mut tokens = Vec::new();
        let mut chunks = Vec::new();
        for region in regions {
            let region_tokens = self.transcribe_region(audio, &probabilities, region, config)?;
            // `build_chunk` yields `None` only for an empty token list, so skipped regions
            // never add tokens either.
            if let Some(chunk) = build_chunk(&region_tokens) {
                chunks.push(chunk);
                tokens.extend(region_tokens);
            }
        }
        Ok(build_result(audio.len(), chunks, tokens))
    }

    /// Transcribes one speech region, choosing a strategy by size:
    ///
    /// 1. the region fits in one window, so it is transcribed as a single segment;
    /// 2. otherwise, if the planner splits it into sub-segments, each is transcribed on its own;
    /// 3. otherwise, overlapping windows are transcribed and merged.
    fn transcribe_region(
        &self,
        audio: &[f32],
        probabilities: &[f32],
        region: SampleRange,
        config: &LongFormConfig,
    ) -> Result<Vec<TimedToken>, TranscriptionError> {
        let region_audio = &audio[region.start..region.end];
        if region_audio.len() <= config.transcription.max_audio_samples {
            return self.transcribe_single_segment(region_audio, region.start, config);
        }
        if let Some(subsegments) = plan_region_subsegments(
            region,
            region_probability_slice(probabilities, region),
            &config.segmentation,
            config.transcription.max_audio_samples,
        ) {
            return self.transcribe_subsegments(audio, subsegments, config);
        }
        self.transcribe_overlap_region(audio, region, config)
    }

    /// Transcribes `audio` (a slice of the full input) in one call and shifts token
    /// timestamps by `sample_offset` so they are relative to the full input.
    fn transcribe_single_segment(
        &self,
        audio: &[f32],
        sample_offset: usize,
        config: &LongFormConfig,
    ) -> Result<Vec<TimedToken>, TranscriptionError> {
        let mut tokens = self
            .inner
            .run_with_config(audio, &config.transcription)?
            .tokens;
        offset_tokens(&mut tokens, sample_offset);
        Ok(tokens)
    }

    /// Transcribes each planned sub-segment independently and concatenates the tokens
    /// in plan order. `audio` is the full input; sub-segment ranges index into it.
    fn transcribe_subsegments(
        &self,
        audio: &[f32],
        subsegments: Vec<SampleRange>,
        config: &LongFormConfig,
    ) -> Result<Vec<TimedToken>, TranscriptionError> {
        let mut tokens = Vec::new();
        for subsegment in subsegments {
            let sub_audio = &audio[subsegment.start..subsegment.end];
            tokens.extend(self.transcribe_single_segment(sub_audio, subsegment.start, config)?);
        }
        Ok(tokens)
    }

    /// Fallback for regions the planner could not split: transcribe overlapping windows,
    /// merge their raw outputs, and decode the merged result into tokens.
    ///
    /// No `offset_tokens` call is made here; presumably `transcribe_raw` already receives
    /// each window's absolute start (confirm against `TranscriptionPipeline::transcribe_raw`).
    fn transcribe_overlap_region(
        &self,
        audio: &[f32],
        region: SampleRange,
        config: &LongFormConfig,
    ) -> Result<Vec<TimedToken>, TranscriptionError> {
        let raw_windows = self.transcribe_overlap_windows(audio, region, config)?;
        let merged = merge_overlapping_windows(raw_windows);
        Ok(self
            .inner
            .decode_raw(&merged, duration_seconds(audio.len()))
            .tokens)
    }

    /// Produces one raw transcription per window planned by `config.overlap.plan(region)`.
    ///
    /// Each window's audio is prefixed with up to `context_samples` of preceding audio
    /// (clamped at the start of the input). `transcribe_raw` receives the window start and
    /// the length of that prefix; presumably these are used to place and then discard the
    /// context (TODO confirm).
    fn transcribe_overlap_windows(
        &self,
        audio: &[f32],
        region: SampleRange,
        config: &LongFormConfig,
    ) -> Result<Vec<crate::decode::RawTranscription>, TranscriptionError> {
        let mut raw_windows = Vec::new();
        for chunk in config.overlap.plan(region) {
            let context_start = chunk.start.saturating_sub(config.overlap.context_samples);
            let chunk_audio = &audio[context_start..chunk.end];
            raw_windows.push(self.inner.transcribe_raw(
                chunk_audio,
                chunk.start,
                chunk.start - context_start,
                &config.transcription,
            )?);
        }
        Ok(raw_windows)
    }
}
/// Concatenates every token's text in order and strips leading/trailing whitespace
/// from the combined string.
fn join_token_text(tokens: &[TimedToken]) -> String {
    let mut combined = String::new();
    for token in tokens {
        combined.push_str(token.text.as_str());
    }
    combined.trim().to_string()
}
/// Summarises a run of tokens as one chunk.
///
/// The chunk spans from the first token's start to the last token's end, and its text is
/// the joined token text. Returns `None` when `tokens` is empty.
fn build_chunk(tokens: &[TimedToken]) -> Option<TranscriptChunk> {
    let (first, last) = tokens.first().zip(tokens.last())?;
    let chunk = TranscriptChunk {
        start: first.start,
        end: last.end,
        text: join_token_text(tokens),
    };
    Some(chunk)
}
/// Assembles the final long-form result from per-region chunks and tokens.
///
/// The top-level `text` is the chunk texts joined by single spaces. Chunks whose text is
/// empty are skipped in that join. `duration_seconds` covers the whole input
/// (`audio_len` samples), not just the detected speech.
fn build_result(
    audio_len: usize,
    chunks: Vec<TranscriptChunk>,
    tokens: Vec<TimedToken>,
) -> TranscriptionResult {
    // `join_token_text` trims, so a region made only of whitespace tokens yields an empty
    // chunk text. Joining it as-is would leave doubled, leading or trailing spaces.
    let text = chunks
        .iter()
        .map(|chunk| chunk.text.as_str())
        .filter(|chunk_text| !chunk_text.is_empty())
        .collect::<Vec<_>>()
        .join(" ");
    TranscriptionResult {
        text,
        chunks,
        tokens,
        duration_seconds: duration_seconds(audio_len),
    }
}
/// Shifts every token's timestamps later by `sample_offset` samples, in place.
fn offset_tokens(tokens: &mut [TimedToken], sample_offset: usize) {
    tokens
        .iter_mut()
        .for_each(|token| offset_token(token, sample_offset));
}
fn offset_token(token: &mut TimedToken, sample_offset: usize) {
let seconds = sample_offset as f64 / SAMPLE_RATE as f64;
token.start += seconds;
token.end += seconds;
}
/// Converts a number of samples into seconds at the crate-wide `SAMPLE_RATE`.
fn duration_seconds(sample_count: usize) -> f64 {
    let samples_per_second = SAMPLE_RATE as f64;
    (sample_count as f64) / samples_per_second
}