native-whisperx 0.1.5

WhisperX-style transcription workflows composed from moritzbrantner Rust building-block crates.
use std::time::Instant;

#[cfg(feature = "diarization")]
use audio_analysis_transcription::TranscriptDiarizationProvider;
use audio_analysis_transcription::{
    transcribe, CandleWhisperTranscriber, EnergyVadTranscriptionProvider,
    ReusableCandleWhisperTranscriber, TranscriptionPipelineRequest, TranscriptionPipelineResponse,
    TranscriptionProviderSelection,
};

use crate::config::{
    AsrProvider, NativeWhisperxConfig, NativeWhisperxError, NativeWhisperxReport, VadMethod,
};
use crate::config_mapping::{
    build_transcription_request, run_native_with_optional_alignment, run_native_with_selected_vad,
};
use crate::output::write_outputs_with_options;
use crate::report::{append_native_alignment_diagnostics, append_native_diarization_diagnostics};

/// Runs the native workflow for one input and returns its report.
///
/// Steps, in order: transcription via `run_transcription_with_cuda_oom_retry` (which may lower
/// `config.asr.max_batch_size`), alignment and diarization diagnostics,
/// `crate::save_draft_speakers_from_response`, and finally output writing. Two timing
/// diagnostics are appended last: `phaseOutputSeconds` (output writing only) and
/// `phaseNativeTotalSeconds` (the whole call).
///
/// # Errors
///
/// Propagates the first error from transcription, draft-speaker saving, or output writing.
pub fn run(mut config: NativeWhisperxConfig) -> Result<NativeWhisperxReport, NativeWhisperxError> {
    let total_timer = Instant::now();

    let mut response = run_transcription_with_cuda_oom_retry(&mut config)?;
    append_native_alignment_diagnostics(&mut response, &config);
    append_native_diarization_diagnostics(&mut response, &config);
    crate::save_draft_speakers_from_response(&mut response, &config)?;

    // Only the file-writing phase is covered by the output timer.
    let output_timer = Instant::now();
    let output_files = write_outputs_with_options(
        &response,
        &config.output,
        config.alignment.return_char_alignments,
    )?;
    let output_seconds = output_timer.elapsed().as_secs_f64();
    response
        .diagnostics
        .push(format!("phaseOutputSeconds={:.6}", output_seconds));

    let total_seconds = total_timer.elapsed().as_secs_f64();
    response
        .diagnostics
        .push(format!("phaseNativeTotalSeconds={:.6}", total_seconds));

    Ok(NativeWhisperxReport {
        response,
        output_files,
    })
}

/// Runs the native workflow for every config, preserving input order.
///
/// If `should_reuse_native_asr_provider` accepts the whole batch, the work is handed to
/// [`run_many_reusing_native_provider`]. Otherwise every config goes through [`run`] on its own.
///
/// # Errors
///
/// Returns the first error encountered; later configs are not processed.
pub fn run_many(
    configs: Vec<NativeWhisperxConfig>,
) -> Result<Vec<NativeWhisperxReport>, NativeWhisperxError> {
    if should_reuse_native_asr_provider(&configs) {
        run_many_reusing_native_provider(configs)
    } else {
        // Collecting into `Result` stops at the first failing config.
        configs.into_iter().map(run).collect()
    }
}

pub fn run_many_reusing_native_provider(
    configs: Vec<NativeWhisperxConfig>,
) -> Result<Vec<NativeWhisperxReport>, NativeWhisperxError> {
    let mut reports = Vec::with_capacity(configs.len());
    let mut reusable_asr: Option<ReusableCandleWhisperTranscriber> = None;

    for mut config in configs {
        let run_started = Instant::now();
        let mut retry_attempts = Vec::new();
        let mut response = loop {
            let request = build_transcription_request(&config)?;
            let TranscriptionProviderSelection::CandleWhisper(options) = &request.provider else {
                return Err(NativeWhisperxError::InvalidConfig(
                    "native multi-input reuse requires the Candle Whisper native provider"
                        .to_string(),
                ));
            };

            let reused_provider = reusable_asr
                .as_ref()
                .is_some_and(|provider| provider.options == *options);
            if !reused_provider {
                reusable_asr = Some(ReusableCandleWhisperTranscriber::new(options.clone()));
            }
            let asr_provider = reusable_asr
                .as_mut()
                .expect("native ASR provider should be initialized");
            let mut vad = EnergyVadTranscriptionProvider;
            match run_with_reusable_asr(request, &config, &mut vad, asr_provider) {
                Ok(mut response) => {
                    response.diagnostics.push(if reused_provider {
                        "nativeMultiInputAsrProvider=reused".to_string()
                    } else {
                        "nativeMultiInputAsrProvider=loaded".to_string()
                    });
                    push_cuda_oom_retry_diagnostics(
                        &mut response,
                        &retry_attempts,
                        config.asr.max_batch_size,
                    );
                    break response;
                }
                Err(error) if is_cuda_oom_error(&error) => {
                    let Some(next_batch_size) =
                        retry_batch_size_after_cuda_oom(config.asr.max_batch_size)
                    else {
                        return Err(error);
                    };
                    retry_attempts.push(config.asr.max_batch_size);
                    config.asr.max_batch_size = Some(next_batch_size);
                    reusable_asr = None;
                }
                Err(error) => return Err(error),
            }
        };
        append_native_alignment_diagnostics(&mut response, &config);
        append_native_diarization_diagnostics(&mut response, &config);
        crate::save_draft_speakers_from_response(&mut response, &config)?;
        let output_started = Instant::now();
        let output_files = write_outputs_with_options(
            &response,
            &config.output,
            config.alignment.return_char_alignments,
        )?;
        response.diagnostics.push(format!(
            "phaseOutputSeconds={:.6}",
            output_started.elapsed().as_secs_f64()
        ));
        response.diagnostics.push(format!(
            "phaseNativeTotalSeconds={:.6}",
            run_started.elapsed().as_secs_f64()
        ));
        reports.push(NativeWhisperxReport {
            response,
            output_files,
        });
    }

    Ok(reports)
}

fn run_transcription_with_cuda_oom_retry(
    config: &mut NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    let mut retry_attempts = Vec::new();

    loop {
        let request = build_transcription_request(config)?;
        match run_transcription_once(request, config) {
            Ok(mut response) => {
                push_cuda_oom_retry_diagnostics(
                    &mut response,
                    &retry_attempts,
                    config.asr.max_batch_size,
                );
                return Ok(response);
            }
            Err(error) if is_cuda_oom_error(&error) => {
                let Some(next_batch_size) =
                    retry_batch_size_after_cuda_oom(config.asr.max_batch_size)
                else {
                    return Err(error);
                };
                retry_attempts.push(config.asr.max_batch_size);
                config.asr.max_batch_size = Some(next_batch_size);
            }
            Err(error) => return Err(error),
        }
    }
}

/// Dispatches a single transcription attempt to the matching execution path.
///
/// * Non-native ASR providers go to `run_with_phase_observer`.
/// * Native ASR with translation enabled goes to `crate::run_native_with_translation`.
/// * Native ASR with the Silero or Pyannote VAD goes to `run_native_with_selected_vad`.
/// * Every other native configuration goes to `run_with_phase_observer`.
fn run_transcription_once(
    request: TranscriptionPipelineRequest,
    config: &NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    if config.asr.provider != AsrProvider::Native {
        return run_with_phase_observer(request, config);
    }
    // Translation takes precedence over the VAD selection.
    if config.translation.enabled {
        return crate::run_native_with_translation(request, config);
    }
    match config.vad.method {
        VadMethod::Silero | VadMethod::Pyannote => run_native_with_selected_vad(request, config),
        _ => run_with_phase_observer(request, config),
    }
}

fn is_cuda_oom_error(error: &NativeWhisperxError) -> bool {
    let message = match error {
        NativeWhisperxError::Transcription(message) => message.to_ascii_lowercase(),
        _ => return false,
    };
    message.contains("cuda")
        && (message.contains("out_of_memory") || message.contains("out of memory"))
}

/// Picks the batch size for the next attempt after a CUDA out-of-memory failure.
///
/// * `None` (no limit configured) falls back to `Some(4)`.
/// * A limit greater than 1 is halved (integer division).
/// * A limit of 1 or 0 yields `None`: there is nothing smaller to try.
fn retry_batch_size_after_cuda_oom(current: Option<usize>) -> Option<usize> {
    let Some(batch_size) = current else {
        return Some(4);
    };
    if batch_size <= 1 {
        return None;
    }
    // Halving a value of at least 2 never drops below 1.
    Some(batch_size / 2)
}

fn push_cuda_oom_retry_diagnostics(
    response: &mut TranscriptionPipelineResponse,
    retry_attempts: &[Option<usize>],
    final_batch_size: Option<usize>,
) {
    if retry_attempts.is_empty() {
        return;
    }

    response
        .diagnostics
        .push(format!("cudaOomRetryCount={}", retry_attempts.len()));
    response.diagnostics.push(format!(
        "cudaOomRetriedBatchSizes={}",
        retry_attempts
            .iter()
            .map(|batch_size| format_batch_size(*batch_size))
            .collect::<Vec<_>>()
            .join(",")
    ));
    response.diagnostics.push(format!(
        "cudaOomFinalBatchSize={}",
        format_batch_size(final_batch_size)
    ));
}

/// Renders a batch size for diagnostics: the decimal number, or `unbounded` for `None`.
fn format_batch_size(batch_size: Option<usize>) -> String {
    match batch_size {
        Some(limit) => limit.to_string(),
        None => "unbounded".to_string(),
    }
}

/// Decides whether a batch qualifies for the shared-ASR-provider path.
///
/// Requires more than one config, and every config must use the native ASR provider, have
/// translation disabled, and use the energy VAD.
fn should_reuse_native_asr_provider(configs: &[NativeWhisperxConfig]) -> bool {
    // A single input (or none) gains nothing from sharing a provider.
    if configs.len() <= 1 {
        return false;
    }

    let is_eligible = |config: &NativeWhisperxConfig| {
        config.asr.provider == AsrProvider::Native
            && !config.translation.enabled
            && matches!(config.vad.method, VadMethod::Energy)
    };
    configs.iter().all(is_eligible)
}

/// Runs the native pipeline for one request using a caller-owned, reusable ASR provider.
///
/// With the `diarization` feature enabled, a diarizer is built from `config` via
/// `crate::native_diarization_provider` and passed on only when `request.diarization.enabled`
/// is set. Without the feature, `config` is unused and no diarizer is passed.
///
/// # Errors
///
/// Propagates errors from diarizer construction (feature `diarization` only) and from
/// `run_native_with_optional_alignment`.
fn run_with_reusable_asr(
    request: TranscriptionPipelineRequest,
    config: &NativeWhisperxConfig,
    vad_provider: &mut EnergyVadTranscriptionProvider,
    asr_provider: &mut ReusableCandleWhisperTranscriber,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    #[cfg(feature = "diarization")]
    {
        // The diarizer is constructed even when diarization is disabled for this request;
        // it is simply not handed to the pipeline in that case.
        let mut diarizer = crate::native_diarization_provider(config)?;
        let diarization_provider = request
            .diarization
            .enabled
            .then_some(&mut diarizer as &mut dyn TranscriptDiarizationProvider);
        run_native_with_optional_alignment(
            request,
            vad_provider,
            asr_provider,
            diarization_provider,
        )
    }

    #[cfg(not(feature = "diarization"))]
    {
        // Silences the unused-parameter warning in builds without diarization.
        let _ = config;
        run_native_with_optional_alignment(request, vad_provider, asr_provider, None)
    }
}

/// Runs one transcription request, using the native pipeline when it applies.
///
/// The request is handed to `transcribe` unchanged when the configured ASR provider is not
/// native, or when the request does not select the Candle Whisper provider. Otherwise the
/// native pipeline runs with the energy VAD and a freshly constructed
/// `CandleWhisperTranscriber`; with the `diarization` feature a diarizer is built from `config`
/// and used only when `request.diarization.enabled` is set.
///
/// # Errors
///
/// Errors from `transcribe` are converted to `NativeWhisperxError::Transcription` carrying the
/// error's `to_string()` text. Errors from diarizer construction and from
/// `run_native_with_optional_alignment` are propagated as-is.
pub(crate) fn run_with_phase_observer(
    request: TranscriptionPipelineRequest,
    config: &NativeWhisperxConfig,
) -> Result<TranscriptionPipelineResponse, NativeWhisperxError> {
    if config.asr.provider != AsrProvider::Native {
        return transcribe(request)
            .map_err(|error| NativeWhisperxError::Transcription(error.to_string()));
    }

    // A native config whose request selects some other provider also takes the generic path.
    let TranscriptionProviderSelection::CandleWhisper(options) = &request.provider else {
        return transcribe(request)
            .map_err(|error| NativeWhisperxError::Transcription(error.to_string()));
    };
    let mut vad = EnergyVadTranscriptionProvider;
    let mut asr_provider = CandleWhisperTranscriber::new(options.clone());

    #[cfg(feature = "diarization")]
    {
        let mut diarizer = crate::native_diarization_provider(config)?;
        let diarization_provider = request
            .diarization
            .enabled
            .then_some(&mut diarizer as &mut dyn TranscriptDiarizationProvider);
        run_native_with_optional_alignment(
            request,
            &mut vad,
            &mut asr_provider,
            diarization_provider,
        )
    }

    #[cfg(not(feature = "diarization"))]
    {
        run_native_with_optional_alignment(request, &mut vad, &mut asr_provider, None)
    }
}

// Unit tests for the CUDA out-of-memory detection and batch-size back-off helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cuda_oom_retry_batch_sizes_step_down_to_one() {
        assert_eq!(retry_batch_size_after_cuda_oom(None), Some(4));
        assert_eq!(retry_batch_size_after_cuda_oom(Some(4)), Some(2));
        assert_eq!(retry_batch_size_after_cuda_oom(Some(2)), Some(1));
        assert_eq!(retry_batch_size_after_cuda_oom(Some(1)), None);
    }

    #[test]
    fn cuda_oom_detection_matches_candle_driver_error() {
        let error = NativeWhisperxError::Transcription(
            "model_output_mismatch: Whisper encoder failed: DriverError(CUDA_ERROR_OUT_OF_MEMORY, \"out of memory\")"
                .to_string(),
        );

        assert!(is_cuda_oom_error(&error));
    }

    #[test]
    fn cuda_oom_detection_ignores_unrelated_errors() {
        // A transcription failure that never mentions CUDA must not trigger a retry.
        let other_failure =
            NativeWhisperxError::Transcription("decoder produced no tokens".to_string());
        assert!(!is_cuda_oom_error(&other_failure));

        // Only the `Transcription` variant is inspected, whatever the message says.
        let wrong_variant = NativeWhisperxError::InvalidConfig(
            "CUDA_ERROR_OUT_OF_MEMORY: out of memory".to_string(),
        );
        assert!(!is_cuda_oom_error(&wrong_variant));
    }
}