native-whisperx 0.1.0

WhisperX-style transcription workflows composed from moritzbrantner Rust building-block crates.
Documentation
//! Diagnostic enrichment for native alignment and diarization workflow reports.

use audio_analysis_transcription::TranscriptionPipelineResponse;
use text_transcripts::TranscriptionContract;

use crate::config::{is_pyannote_diarization_model, AsrProvider, NativeWhisperxConfig};

/// Appends the diarization phase markers to `response.diagnostics`.
///
/// The function is a no-op unless all of the following hold:
/// the ASR provider is [`AsrProvider::Native`], diarization is enabled, and
/// `is_pyannote_diarization_model` accepts the configured diarization model id.
///
/// Each marker is appended only when an identical string is not already
/// present, so calling this more than once does not duplicate entries.
pub(crate) fn append_native_diarization_diagnostics(
    response: &mut TranscriptionPipelineResponse,
    config: &NativeWhisperxConfig,
) {
    // Phase markers, in the order they are appended.
    const PHASE_MARKERS: [&str; 5] = [
        "diarizationPhase=segmentation",
        "diarizationPhase=embedding",
        "diarizationPhase=plda",
        "diarizationPhase=vbx",
        "diarizationPhase=clustering",
    ];

    if config.asr.provider != AsrProvider::Native {
        return;
    }
    if !config.diarization.enabled {
        return;
    }
    if !is_pyannote_diarization_model(&config.diarization.model_id) {
        return;
    }

    for marker in PHASE_MARKERS {
        let already_recorded = response.diagnostics.iter().any(|entry| entry == marker);
        if already_recorded {
            continue;
        }
        response.diagnostics.push(marker.to_owned());
    }
}

/// Appends alignment diagnostics to `response.diagnostics`.
///
/// The function is a no-op unless the ASR provider is [`AsrProvider::Native`]
/// and alignment is enabled. Otherwise it contributes, in this order:
///
/// * `alignmentModelId` — the configured model id after canonicalization,
/// * `alignmentFallbackCount` — always written as `0` here,
/// * `alignmentRetryCount` — always written as `0` here,
/// * `alignmentWordTimingMissingCount` — words lacking a start or end time,
/// * `alignmentCharTimingMissingCount` — characters lacking a start or end
///   time, or `0` when character alignments were not requested.
///
/// An entry is skipped when the diagnostics already contain any string
/// beginning with `<key>=`, so previously recorded values are never replaced.
pub(crate) fn append_native_alignment_diagnostics(
    response: &mut TranscriptionPipelineResponse,
    config: &NativeWhisperxConfig,
) {
    if config.asr.provider != AsrProvider::Native {
        return;
    }
    let alignment = &config.alignment;
    if !alignment.enabled {
        return;
    }

    let model_id = canonical_alignment_model_id(&alignment.model_id);
    let words_without_timing = alignment_word_timing_missing_count(&response.transcript);
    // Characters are only inspected when character alignments were requested.
    let chars_without_timing = if alignment.return_char_alignments {
        alignment_char_timing_missing_count(&response.transcript)
    } else {
        0
    };

    let entries = [
        (
            "alignmentModelId",
            format!("alignmentModelId={}", model_id),
        ),
        (
            "alignmentFallbackCount",
            String::from("alignmentFallbackCount=0"),
        ),
        (
            "alignmentRetryCount",
            String::from("alignmentRetryCount=0"),
        ),
        (
            "alignmentWordTimingMissingCount",
            format!("alignmentWordTimingMissingCount={}", words_without_timing),
        ),
        (
            "alignmentCharTimingMissingCount",
            format!("alignmentCharTimingMissingCount={}", chars_without_timing),
        ),
    ];

    for (key, entry) in entries {
        push_diagnostic_if_missing(&mut response.diagnostics, key, entry);
    }
}

/// Resolves an alignment model id to the form reported in diagnostics.
///
/// The alias `WAV2VEC2_ASR_BASE_960H` (compared ASCII case-insensitively) is
/// rewritten to `facebook/wav2vec2-base-960h`; every other id is returned
/// unchanged.
fn canonical_alignment_model_id(model_id: &str) -> &str {
    const ALIAS: &str = "WAV2VEC2_ASR_BASE_960H";
    const CANONICAL: &str = "facebook/wav2vec2-base-960h";

    if !model_id.eq_ignore_ascii_case(ALIAS) {
        return model_id;
    }
    CANONICAL
}

/// Appends `diagnostic` unless an entry for `key` is already recorded.
///
/// An entry counts as recorded when it begins with `key` immediately followed
/// by `=`. The `diagnostic` string itself is pushed verbatim; it is not checked
/// against `key`.
fn push_diagnostic_if_missing(diagnostics: &mut Vec<String>, key: &str, diagnostic: String) {
    // Checking "starts with `key`, then `=`" avoids building a temporary
    // `key=` string while matching exactly the same entries.
    let already_recorded = diagnostics
        .iter()
        .any(|entry| matches!(entry.strip_prefix(key), Some(rest) if rest.starts_with('=')));

    if !already_recorded {
        diagnostics.push(diagnostic);
    }
}

/// Counts the words, across all transcript segments, that lack a start time,
/// an end time, or both.
fn alignment_word_timing_missing_count(transcript: &TranscriptionContract) -> usize {
    transcript
        .segments
        .iter()
        .map(|segment| {
            segment
                .words
                .iter()
                .filter(|word| word.start_seconds.is_none() || word.end_seconds.is_none())
                .count()
        })
        .sum()
}

/// Counts the characters, across all transcript segments, that lack a start
/// time, an end time, or both.
fn alignment_char_timing_missing_count(transcript: &TranscriptionContract) -> usize {
    transcript
        .segments
        .iter()
        .map(|segment| {
            segment
                .chars
                .iter()
                .filter(|ch| ch.start_seconds.is_none() || ch.end_seconds.is_none())
                .count()
        })
        .sum()
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    use audio_analysis_transcription::TranscriptionPipelineResponse;

    use crate::config::{
        AlignmentConfig, AsrConfig, DiarizationConfig, InputSource, NativeWhisperxConfig,
        OutputConfig, TranslationConfig, VadConfig,
    };
    use crate::import_whisperx_json;

    const WHISPERX_SAMPLE: &[u8] =
        include_bytes!("../../../tests/fixtures/whisperx-parity-sample.json");

    /// Builds a config that uses defaults everywhere except the alignment and
    /// diarization sections, which each test supplies.
    fn config_with(
        alignment: AlignmentConfig,
        diarization: DiarizationConfig,
    ) -> NativeWhisperxConfig {
        NativeWhisperxConfig {
            input: InputSource::Path {
                path: PathBuf::from("sample.wav"),
            },
            asr: AsrConfig::default(),
            translation: TranslationConfig::default(),
            vad: VadConfig::default(),
            alignment,
            diarization,
            output: OutputConfig::default(),
        }
    }

    #[test]
    fn native_pyannote_diarization_diagnostics_identify_phases() {
        let diarization = DiarizationConfig {
            enabled: true,
            model_id: "pyannote/speaker-diarization-community-1".to_string(),
            model_bundle: Some(PathBuf::from("/models/pyannote-diarization")),
            ..DiarizationConfig::default()
        };
        let config = config_with(AlignmentConfig::default(), diarization);
        let mut response = fixture_response_with_chars();

        append_native_diarization_diagnostics(&mut response, &config);

        let expected_markers = [
            "diarizationPhase=segmentation",
            "diarizationPhase=embedding",
            "diarizationPhase=plda",
            "diarizationPhase=vbx",
            "diarizationPhase=clustering",
        ];
        for expected in expected_markers {
            let present = response
                .diagnostics
                .iter()
                .any(|diagnostic| diagnostic == expected);
            assert!(present, "missing {expected}: {:?}", response.diagnostics);
        }
    }

    #[test]
    fn native_alignment_diagnostics_include_fallback_and_retry_counts() {
        let alignment = AlignmentConfig {
            enabled: true,
            return_char_alignments: true,
            ..AlignmentConfig::default()
        };
        let config = config_with(alignment, DiarizationConfig::default());
        let mut response = fixture_response_with_chars();

        append_native_alignment_diagnostics(&mut response, &config);

        let expected_entries = [
            "alignmentFallbackCount=0",
            "alignmentRetryCount=0",
            "alignmentWordTimingMissingCount=0",
            "alignmentCharTimingMissingCount=0",
        ];
        for expected in expected_entries {
            let present = response
                .diagnostics
                .iter()
                .any(|diagnostic| diagnostic == expected);
            assert!(
                present,
                "diagnostics should include `{expected}`: {:?}",
                response.diagnostics
            );
        }
    }

    /// Imports the WhisperX sample fixture, adds one fully timed character to
    /// its first segment, and wraps the transcript in an otherwise empty
    /// pipeline response.
    fn fixture_response_with_chars() -> TranscriptionPipelineResponse {
        let timed_char = text_transcripts::TranscriptCharContract {
            character: "h".to_string(),
            start_seconds: Some(0.0),
            end_seconds: Some(0.1),
            confidence: Some(0.9),
            attributes: Default::default(),
        };

        let mut transcript = import_whisperx_json(WHISPERX_SAMPLE).expect("fixture should import");
        transcript.segments[0].chars.push(timed_char);

        TranscriptionPipelineResponse {
            accepted: true,
            operation: "audio.transcription.transcribe".to_string(),
            provider: "fixture".to_string(),
            model_id: "fixture".to_string(),
            transcript,
            vad_segments: Vec::new(),
            alignment: None,
            diarization: None,
            artifacts: Vec::new(),
            diagnostics: Vec::new(),
        }
    }
}