use std::time::Instant;
#[cfg(feature = "diarization")]
use audio_analysis_transcription::TranscriptDiarizationProvider;
use audio_analysis_transcription::{
transcribe, CandleWhisperTranscriber, EnergyVadTranscriptionProvider,
ReusableCandleWhisperTranscriber, TranscriptionPipelineRequest, TranscriptionPipelineResponse,
TranscriptionProviderSelection,
};
use crate::config::{
AsrProvider, NativeWhisperxConfig, NativeWhisperxError, NativeWhisperxReport, VadMethod,
};
use crate::config_mapping::{
build_transcription_request, run_native_with_optional_alignment, run_native_with_selected_vad,
};
use crate::output::write_outputs_with_options;
use crate::report::{append_native_alignment_diagnostics, append_native_diarization_diagnostics};
/// Runs the native pipeline for a single configuration and returns its report.
///
/// The steps happen in this order:
/// 1. transcription through `run_transcription_with_cuda_oom_retry`, which is
///    given mutable access to `config` and may change it while retrying;
/// 2. `append_native_alignment_diagnostics` and
///    `append_native_diarization_diagnostics` on the response;
/// 3. `crate::save_draft_speakers_from_response`;
/// 4. `write_outputs_with_options`, timed separately.
///
/// Two timing entries are appended to `response.diagnostics` at the end:
/// `phaseOutputSeconds` (output writing only) and `phaseNativeTotalSeconds`
/// (everything since this function was entered).
///
/// # Errors
///
/// Returns the first error produced by transcription, by saving draft
/// speakers, or by writing the outputs.
pub fn run(mut config: NativeWhisperxConfig) -> Result<NativeWhisperxReport, NativeWhisperxError> {
    let total_timer = Instant::now();

    // Transcription phase plus the diagnostics derived from the final config.
    let mut response = run_transcription_with_cuda_oom_retry(&mut config)?;
    append_native_alignment_diagnostics(&mut response, &config);
    append_native_diarization_diagnostics(&mut response, &config);
    crate::save_draft_speakers_from_response(&mut response, &config)?;

    // Output phase, measured on its own clock.
    let output_timer = Instant::now();
    let output_files = write_outputs_with_options(
        &response,
        &config.output,
        config.alignment.return_char_alignments,
    )?;
    let output_seconds = output_timer.elapsed().as_secs_f64();
    response
        .diagnostics
        .push(format!("phaseOutputSeconds={:.6}", output_seconds));

    let total_seconds = total_timer.elapsed().as_secs_f64();
    response
        .diagnostics
        .push(format!("phaseNativeTotalSeconds={:.6}", total_seconds));

    Ok(NativeWhisperxReport {
        response,
        output_files,
    })
}
/// Runs every configuration and returns one report per configuration, in
/// input order.
///
/// When `should_reuse_native_asr_provider` accepts the whole batch, the work
/// is delegated to `run_many_reusing_native_provider`; otherwise each
/// configuration goes through `run` independently.
///
/// # Errors
///
/// Stops at, and returns, the first error encountered.
pub fn run_many(
    configs: Vec<NativeWhisperxConfig>,
) -> Result<Vec<NativeWhisperxReport>, NativeWhisperxError> {
    if !should_reuse_native_asr_provider(&configs) {
        // Independent runs; collecting into `Result` short-circuits on error.
        return configs.into_iter().map(run).collect();
    }
    run_many_reusing_native_provider(configs)
}
/// Runs the configurations one after another while keeping a single
/// `ReusableCandleWhisperTranscriber` alive between inputs.
///
/// For every attempt the request is rebuilt from the (possibly adjusted)
/// config. The existing provider is kept only when its `options` equal the
/// Candle Whisper options of the new request; otherwise a new provider is
/// constructed. Each successful response records which of the two happened
/// via a `nativeMultiInputAsrProvider=reused|loaded` diagnostic.
///
/// Failed attempts are retried in two situations, and in both the current
/// provider is dropped so the next attempt constructs a new one:
/// * a CUDA out-of-memory error, as long as
///   `retry_batch_size_after_cuda_oom` yields a smaller batch size;
/// * an alignment failure accepted by `should_retry_without_alignment`, in
///   which case alignment is switched off in the config.
///
/// After a successful attempt the same post-processing as in `run` is applied
/// (alignment/diarization diagnostics, draft speakers, outputs, timings).
///
/// # Errors
///
/// * `NativeWhisperxError::InvalidConfig` when a request does not select the
///   Candle Whisper provider.
/// * Any non-retryable error from building the request, transcribing, saving
///   draft speakers, or writing outputs. Processing stops at the first one.
pub fn run_many_reusing_native_provider(
    configs: Vec<NativeWhisperxConfig>,
) -> Result<Vec<NativeWhisperxReport>, NativeWhisperxError> {
    let mut reports = Vec::with_capacity(configs.len());
    let mut reusable_asr: Option<ReusableCandleWhisperTranscriber> = None;

    for mut config in configs {
        let total_timer = Instant::now();
        // Batch sizes that hit CUDA OOM for this input, oldest first.
        let mut oom_batch_sizes: Vec<Option<usize>> = Vec::new();
        let mut alignment_fallback = false;

        let mut response = loop {
            let request = build_transcription_request(&config)?;
            let TranscriptionProviderSelection::CandleWhisper(options) = &request.provider else {
                return Err(NativeWhisperxError::InvalidConfig(
                    "native multi-input reuse requires the Candle Whisper native provider"
                        .to_string(),
                ));
            };

            // Keep the current provider only if it was built with equal options.
            let reused_provider =
                matches!(&reusable_asr, Some(provider) if provider.options == *options);
            if !reused_provider {
                reusable_asr = Some(ReusableCandleWhisperTranscriber::new(options.clone()));
            }
            let asr_provider = reusable_asr
                .as_mut()
                .expect("native ASR provider should be initialized");
            let mut vad = EnergyVadTranscriptionProvider;

            let failure = match run_with_reusable_asr(request, &config, &mut vad, asr_provider) {
                Ok(mut transcribed) => {
                    let provider_state = if reused_provider {
                        "nativeMultiInputAsrProvider=reused"
                    } else {
                        "nativeMultiInputAsrProvider=loaded"
                    };
                    transcribed.diagnostics.push(provider_state.to_string());
                    push_cuda_oom_retry_diagnostics(
                        &mut transcribed,
                        &oom_batch_sizes,
                        config.asr.max_batch_size,
                    );
                    push_alignment_fallback_diagnostics(&mut transcribed, alignment_fallback);
                    break transcribed;
                }
                Err(failure) => failure,
            };

            // Decide whether the failure is retryable; anything else is final.
            if is_cuda_oom_error(&failure) {
                let Some(smaller_batch) =
                    retry_batch_size_after_cuda_oom(config.asr.max_batch_size)
                else {
                    return Err(failure);
                };
                oom_batch_sizes.push(config.asr.max_batch_size);
                config.asr.max_batch_size = Some(smaller_batch);
            } else if should_retry_without_alignment(&failure, &config, alignment_fallback) {
                alignment_fallback = true;
                config.alignment.enabled = false;
            } else {
                return Err(failure);
            }
            // Both retry paths discard the provider before the next attempt.
            reusable_asr = None;
        };

        append_native_alignment_diagnostics(&mut response, &config);
        append_native_diarization_diagnostics(&mut response, &config);
        crate::save_draft_speakers_from_response(&mut response, &config)?;

        let output_timer = Instant::now();
        let output_files = write_outputs_with_options(
            &response,
            &config.output,
            config.alignment.return_char_alignments,
        )?;
        let output_seconds = output_timer.elapsed().as_secs_f64();
        response
            .diagnostics
            .push(format!("phaseOutputSeconds={:.6}", output_seconds));

        let total_seconds = total_timer.elapsed().as_secs_f64();
        response
            .diagnostics
            .push(format!("phaseNativeTotalSeconds={:.6}", total_seconds));

        reports.push(NativeWhisperxReport {
            response,
            output_files,
        });
    }

    Ok(reports)
}
/// Transcribes once per attempt until an attempt succeeds or fails with an
/// error that is not retryable.
///
/// Retry rules, checked in this order after a failed attempt:
/// 1. CUDA out-of-memory: lower `config.asr.max_batch_size` to the value
///    given by `retry_batch_size_after_cuda_oom`; if there is no smaller
///    value, the error is returned.
/// 2. Alignment failure accepted by `should_retry_without_alignment`: set
///    `config.alignment.enabled` to `false` and try again.
///
/// The adjustments are written into `config`, so the caller sees the settings
/// that finally worked. On success the response additionally receives the
/// CUDA-OOM retry and alignment-fallback diagnostics.
fn run_transcription_with_cuda_oom_retry(
    config: &mut NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    // Batch sizes that hit CUDA OOM, oldest first.
    let mut oom_batch_sizes: Vec<Option<usize>> = Vec::new();
    let mut alignment_fallback = false;

    loop {
        let request = build_transcription_request(config)?;
        let failure = match run_transcription_once(request, config) {
            Ok(mut response) => {
                push_cuda_oom_retry_diagnostics(
                    &mut response,
                    &oom_batch_sizes,
                    config.asr.max_batch_size,
                );
                push_alignment_fallback_diagnostics(&mut response, alignment_fallback);
                return Ok(response);
            }
            Err(failure) => failure,
        };

        if is_cuda_oom_error(&failure) {
            let Some(smaller_batch) = retry_batch_size_after_cuda_oom(config.asr.max_batch_size)
            else {
                return Err(failure);
            };
            oom_batch_sizes.push(config.asr.max_batch_size);
            config.asr.max_batch_size = Some(smaller_batch);
        } else if should_retry_without_alignment(&failure, config, alignment_fallback) {
            alignment_fallback = true;
            config.alignment.enabled = false;
        } else {
            return Err(failure);
        }
    }
}
/// Performs a single transcription attempt, picking the execution path from
/// the configuration:
///
/// * native ASR with translation enabled → `crate::run_native_with_translation`;
/// * native ASR with the Silero or Pyannote VAD method →
///   `run_native_with_selected_vad`;
/// * everything else → `run_with_phase_observer`.
///
/// Translation takes precedence over the VAD method check.
fn run_transcription_once(
    request: TranscriptionPipelineRequest,
    config: &NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    let native_asr = config.asr.provider == AsrProvider::Native;

    if native_asr && config.translation.enabled {
        return crate::run_native_with_translation(request, config);
    }

    let selected_vad = matches!(config.vad.method, VadMethod::Silero | VadMethod::Pyannote);
    if native_asr && selected_vad {
        return run_native_with_selected_vad(request, config);
    }

    run_with_phase_observer(request, config)
}
fn is_cuda_oom_error(error: &NativeWhisperxError) -> bool {
let message = match error {
NativeWhisperxError::Transcription(message) => message.to_ascii_lowercase(),
_ => return false,
};
message.contains("cuda")
&& (message.contains("out_of_memory") || message.contains("out of memory"))
}
/// Reports whether `error` is a `Transcription` error whose message contains
/// `ctc path is impossible`, compared case-insensitively (ASCII).
fn is_alignment_ctc_path_error(error: &NativeWhisperxError) -> bool {
    matches!(
        error,
        NativeWhisperxError::Transcription(message)
            if message.to_ascii_lowercase().contains("ctc path is impossible")
    )
}
fn should_retry_without_alignment(
error: &NativeWhisperxError,
config: &NativeWhisperxConfig,
already_retried: bool,
) -> bool {
config.asr.provider == AsrProvider::Native
&& config.alignment.enabled
&& !already_retried
&& is_alignment_ctc_path_error(error)
}
/// Picks the batch size for the next attempt after a CUDA out-of-memory
/// failure.
///
/// * `None` (no batch-size limit configured) → `Some(4)`.
/// * `Some(n)` with `n > 1` → `Some(n / 2)` (integer division, so never
///   below 1).
/// * `Some(1)` or `Some(0)` → `None`: there is nothing smaller to try.
fn retry_batch_size_after_cuda_oom(current: Option<usize>) -> Option<usize> {
    /// First bounded batch size tried when no limit was configured.
    const FIRST_BOUNDED_BATCH_SIZE: usize = 4;

    let Some(batch_size) = current else {
        return Some(FIRST_BOUNDED_BATCH_SIZE);
    };
    // For any value above 1 the halved size is already at least 1.
    (batch_size > 1).then(|| batch_size / 2)
}
/// Appends the alignment-fallback diagnostics to `response` when
/// `alignment_fallback` is `true`; does nothing otherwise.
///
/// Two entries are added, in this order: `alignmentFallbackCount=1` and
/// `alignmentFallbackReason=ctc-path-impossible`.
fn push_alignment_fallback_diagnostics(
    response: &mut TranscriptionPipelineResponse,
    alignment_fallback: bool,
) {
    if alignment_fallback {
        for line in [
            "alignmentFallbackCount=1",
            "alignmentFallbackReason=ctc-path-impossible",
        ] {
            response.diagnostics.push(line.to_string());
        }
    }
}
fn push_cuda_oom_retry_diagnostics(
response: &mut TranscriptionPipelineResponse,
retry_attempts: &[Option<usize>],
final_batch_size: Option<usize>,
) {
if retry_attempts.is_empty() {
return;
}
response
.diagnostics
.push(format!("cudaOomRetryCount={}", retry_attempts.len()));
response.diagnostics.push(format!(
"cudaOomRetriedBatchSizes={}",
retry_attempts
.iter()
.map(|batch_size| format_batch_size(*batch_size))
.collect::<Vec<_>>()
.join(",")
));
response.diagnostics.push(format!(
"cudaOomFinalBatchSize={}",
format_batch_size(final_batch_size)
));
}
/// Renders an optional batch size for diagnostics: the decimal number when a
/// size is set, or the literal `unbounded` when it is `None`.
fn format_batch_size(batch_size: Option<usize>) -> String {
    match batch_size {
        Some(size) => size.to_string(),
        None => "unbounded".to_string(),
    }
}
/// Decides whether a batch of configurations can share one native ASR
/// provider.
///
/// This requires more than one configuration, and every configuration must
/// use the native ASR provider, have translation disabled, and use the
/// `Energy` VAD method.
fn should_reuse_native_asr_provider(configs: &[NativeWhisperxConfig]) -> bool {
    if configs.len() <= 1 {
        return false;
    }
    for config in configs {
        let eligible = config.asr.provider == AsrProvider::Native
            && !config.translation.enabled
            && matches!(config.vad.method, VadMethod::Energy);
        if !eligible {
            return false;
        }
    }
    true
}
// Unit tests for the retry helpers.
// NOTE(review): further items are defined after this module in the file;
// clippy's `items_after_test_module` lint would presumably flag that — consider
// moving the module to the end of the file.
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `message` in the `Transcription` error variant.
    fn transcription_error(message: &str) -> NativeWhisperxError {
        NativeWhisperxError::Transcription(message.to_string())
    }

    #[test]
    fn cuda_oom_retry_batch_sizes_step_down_to_one() {
        // (current batch size, expected next batch size)
        let expected_steps = [
            (None, Some(4)),
            (Some(4), Some(2)),
            (Some(2), Some(1)),
            (Some(1), None),
        ];
        for (current, next) in expected_steps {
            assert_eq!(retry_batch_size_after_cuda_oom(current), next);
        }
    }

    #[test]
    fn cuda_oom_detection_matches_candle_driver_error() {
        let oom = transcription_error(
            "model_output_mismatch: Whisper encoder failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")",
        );
        assert!(is_cuda_oom_error(&oom));
    }

    #[test]
    fn alignment_ctc_path_detection_matches_runtime_error() {
        let ctc_failure =
            transcription_error("invalid argument: model_output_mismatch: CTC path is impossible");
        assert!(is_alignment_ctc_path_error(&ctc_failure));
    }
}
/// Runs one request through `run_native_with_optional_alignment` using VAD
/// and ASR providers owned by the caller, so the ASR provider can be reused
/// for further inputs.
///
/// With the `diarization` feature, a diarization provider is built from
/// `config` and handed to the pipeline only when `request.diarization.enabled`
/// is set. Without the feature, `config` is unused and no diarization
/// provider is passed.
///
/// NOTE(review): with the feature enabled, `crate::native_diarization_provider`
/// is called (and its error propagated) even when diarization is disabled for
/// the request — confirm that construction is cheap and cannot fail
/// spuriously.
fn run_with_reusable_asr(
    request: TranscriptionPipelineRequest,
    config: &NativeWhisperxConfig,
    vad_provider: &mut EnergyVadTranscriptionProvider,
    asr_provider: &mut ReusableCandleWhisperTranscriber,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    #[cfg(feature = "diarization")]
    {
        let mut speaker_provider = crate::native_diarization_provider(config)?;
        let diarization_provider = if request.diarization.enabled {
            Some(&mut speaker_provider as &mut dyn TranscriptDiarizationProvider)
        } else {
            None
        };
        return run_native_with_optional_alignment(
            request,
            vad_provider,
            asr_provider,
            diarization_provider,
        );
    }
    #[cfg(not(feature = "diarization"))]
    {
        // `config` is only needed for diarization.
        let _ = config;
        run_native_with_optional_alignment(request, vad_provider, asr_provider, None)
    }
}
pub(crate) fn run_with_phase_observer(
request: TranscriptionPipelineRequest,
config: &NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
if config.asr.provider != AsrProvider::Native {
return transcribe(request)
.map_err(|error| NativeWhisperxError::Transcription(error.to_string()));
}
let TranscriptionProviderSelection::CandleWhisper(options) = &request.provider else {
return transcribe(request)
.map_err(|error| NativeWhisperxError::Transcription(error.to_string()));
};
let mut vad = EnergyVadTranscriptionProvider;
let mut asr_provider = CandleWhisperTranscriber::new(options.clone());
#[cfg(feature = "diarization")]
{
let mut diarizer = crate::native_diarization_provider(config)?;
let diarization_provider = request
.diarization
.enabled
.then_some(&mut diarizer as &mut dyn TranscriptDiarizationProvider);
run_native_with_optional_alignment(
request,
&mut vad,
&mut asr_provider,
diarization_provider,
)
}
#[cfg(not(feature = "diarization"))]
{
run_native_with_optional_alignment(request, &mut vad, &mut asr_provider, None)
}
}