//! shinkai-translator 0.1.3
//!
//! CLI tool for translating video subtitles with LLMs through OpenAI-compatible
//! APIs, with native PGS OCR.
use crate::domain::{
    AssClassificationPolicy, CueClassification, CueKind, SubtitleClassificationEntry,
    SubtitleClassificationReport, SubtitleClassificationSummary, SubtitleCue, SubtitleDocument,
};

/// Classify every cue of `document` under `policy`.
///
/// Returns an annotated copy of the document together with a report that
/// summarizes the classification outcome for review or logging.
pub fn classify_document(
    document: &SubtitleDocument,
    policy: &AssClassificationPolicy,
) -> (SubtitleDocument, SubtitleClassificationReport) {
    let mut classified = document.clone();

    // Explicit signals (karaoke overrides, metadata markers) win; everything
    // else falls back to the default classification.
    for cue in classified.cues_mut() {
        let classification =
            classify_explicit(cue, policy).unwrap_or_else(CueClassification::default);
        cue.set_classification(classification);
    }

    // Optional heuristic pass: promote runs of lyric-like dialogue to songs.
    if policy.enable_inferred_song_detection {
        classify_inferred_song_runs(classified.cues_mut(), policy);
    }

    let report = build_report(&classified);
    (classified, report)
}

/// Return an explicit classification for `cue`, if any signal matches.
///
/// Karaoke timing overrides (`\k`/`\K`) take precedence (confidence 100);
/// otherwise song markers in the cue metadata are checked (confidence 95).
/// Returns `None` when no explicit signal is present.
fn classify_explicit(
    cue: &SubtitleCue,
    policy: &AssClassificationPolicy,
) -> Option<CueClassification> {
    if contains_karaoke_override(cue.text()) {
        return Some(CueClassification::new(
            CueKind::Karaoke,
            policy.karaoke_policy,
            100,
            Some("karaoke timing override detected".to_owned()),
        ));
    }

    metadata_song_reason(cue, policy).map(|reason| {
        CueClassification::new(CueKind::Song, policy.explicit_song_policy, 95, Some(reason))
    })
}

/// Promote runs of short, contiguous dialogue cues to `CueKind::Song`.
///
/// Scans left to right: whenever a dialogue cue looks lyric-like, the run is
/// extended while the following cues are also lyric-like dialogue and are
/// contiguous per `cues_are_contiguous`. Runs of at least
/// `policy.min_inferred_song_run_length` cues are reclassified with
/// confidence 60 and the policy's inferred-song disposition.
fn classify_inferred_song_runs(cues: &mut [SubtitleCue], policy: &AssClassificationPolicy) {
    let mut index = 0usize;

    while index < cues.len() {
        // Only a lyric-like dialogue cue can start a run.
        if cues[index].kind() != CueKind::Dialogue || !looks_like_song_candidate(&cues[index]) {
            index += 1;
            continue;
        }

        let run_start = index;
        let mut run_end = index;

        // Grow the run while the next cue is also lyric-like dialogue and
        // starts close enough to the current cue's end.
        while run_end + 1 < cues.len()
            && cues[run_end + 1].kind() == CueKind::Dialogue
            && looks_like_song_candidate(&cues[run_end + 1])
            && cues_are_contiguous(&cues[run_end], &cues[run_end + 1])
        {
            run_end += 1;
        }

        let run_length = run_end - run_start + 1;
        if run_length >= policy.min_inferred_song_run_length {
            let reason = format!("inferred song run of {run_length} short contiguous cues");
            for cue in &mut cues[run_start..=run_end] {
                cue.set_classification(CueClassification::new(
                    CueKind::Song,
                    policy.inferred_song_policy,
                    60,
                    Some(reason.clone()),
                ));
            }
        }

        // Resume scanning just past the run (or past the single rejected cue).
        index = run_end + 1;
    }
}

fn build_report(document: &SubtitleDocument) -> SubtitleClassificationReport {
    let entries = document
        .cues()
        .iter()
        .map(|cue| SubtitleClassificationEntry {
            cue_id: cue.id().to_owned(),
            kind: cue.kind(),
            disposition: cue.disposition(),
            confidence: cue.classification_confidence(),
            reason: cue.classification_reason().map(ToOwned::to_owned),
            start: cue.start().to_owned(),
            end: cue.end().to_owned(),
            text_preview: preview_text(cue.text()),
        })
        .collect::<Vec<_>>();

    let summary = SubtitleClassificationSummary {
        total_cues: document.cue_count(),
        translatable_cues: document.translatable_cue_count(),
        preserved_cues: document.preserved_cue_count(),
        review_cues: document.review_cue_count(),
        dialogue_cues: document
            .cues()
            .iter()
            .filter(|cue| cue.kind() == CueKind::Dialogue)
            .count(),
        karaoke_cues: document
            .cues()
            .iter()
            .filter(|cue| cue.kind() == CueKind::Karaoke)
            .count(),
        song_cues: document
            .cues()
            .iter()
            .filter(|cue| cue.kind() == CueKind::Song)
            .count(),
    };

    SubtitleClassificationReport {
        format: document.format(),
        summary,
        entries,
    }
}

/// Produce a single-line preview of `text`, truncated to at most 120
/// characters, with `"..."` appended when truncation occurred.
///
/// Newlines are flattened to spaces so the preview stays on one line.
fn preview_text(text: &str) -> String {
    const MAX_PREVIEW_CHARS: usize = 120;

    let mut preview = String::new();
    // Track the emitted character count instead of re-counting `preview`
    // every iteration (the old `preview.chars().count()` check was O(n²)).
    let mut emitted = 0usize;
    for character in text.chars() {
        if emitted >= MAX_PREVIEW_CHARS {
            preview.push_str("...");
            break;
        }
        preview.push(if character == '\n' { ' ' } else { character });
        emitted += 1;
    }
    preview
}

/// Report whether `text` contains an ASS karaoke timing override (`\k`/`\K`).
fn contains_karaoke_override(text: &str) -> bool {
    let bytes = text.as_bytes();
    bytes
        .iter()
        .zip(bytes.iter().skip(1))
        .any(|(&first, &second)| first == b'\\' && (second == b'k' || second == b'K'))
}

/// Return a song-match reason from cue metadata, checking the `Style`,
/// `Effect`, and `Name` fields (in that order) against the policy's markers.
fn metadata_song_reason(cue: &SubtitleCue, policy: &AssClassificationPolicy) -> Option<String> {
    let checks: [(&str, &[String]); 3] = [
        ("Style", &policy.style_markers),
        ("Effect", &policy.effect_markers),
        ("Name", &policy.name_markers),
    ];

    checks
        .into_iter()
        .find_map(|(field, markers)| match_marker_field(field, cue, markers))
}

/// Check one cue attribute (`field`) against the configured `markers`,
/// returning a human-readable reason on the first match.
///
/// `op`/`ed` markers are matched as whole tokens (optionally suffixed with
/// digits, e.g. `OP2`) to avoid false hits inside unrelated words; all other
/// markers use a normalized substring match.
fn match_marker_field(field: &str, cue: &SubtitleCue, markers: &[String]) -> Option<String> {
    let value = cue.attributes().get(field)?;
    let normalized_value = normalize_marker_string(value);

    markers.iter().find_map(|marker| {
        let normalized_marker = normalize_marker_string(marker);
        if normalized_marker.is_empty() {
            return None;
        }

        let matched = if normalized_marker == "op" || normalized_marker == "ed" {
            tokenize_ascii_words(value)
                .iter()
                .any(|token| song_marker_token(token, &normalized_marker))
        } else {
            normalized_value.contains(&normalized_marker)
        };

        matched.then(|| format!("song marker `{marker}` matched {field}"))
    })
}

/// Normalize a marker or attribute value for comparison: lowercase ASCII
/// alphanumerics, with every other character run collapsed to a single space
/// and the result trimmed.
fn normalize_marker_string(value: &str) -> String {
    let mut normalized = String::with_capacity(value.len());

    for character in value.chars() {
        if character.is_ascii_alphanumeric() {
            normalized.push(character.to_ascii_lowercase());
        } else if !normalized.is_empty() && !normalized.ends_with(' ') {
            // Collapse separator runs into one space; skip leading separators.
            normalized.push(' ');
        }
    }

    normalized.trim().to_owned()
}

/// Split `value` into lowercase ASCII-alphanumeric words, treating every
/// other character (including non-ASCII) as a separator.
fn tokenize_ascii_words(value: &str) -> Vec<String> {
    value
        .split(|character: char| !character.is_ascii_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(|word| word.to_ascii_lowercase())
        .collect()
}

/// True when `token` is exactly `prefix` or `prefix` followed only by ASCII
/// digits (e.g. `op`, `op2`, `ed12`).
fn song_marker_token(token: &str, prefix: &str) -> bool {
    match token.strip_prefix(prefix) {
        Some(rest) => rest.bytes().all(|byte| byte.is_ascii_digit()),
        None => false,
    }
}

/// Heuristic for lyric-like cues: short visible text (at most 8 words and 64
/// non-space characters), containing letters but no sentence punctuation.
fn looks_like_song_candidate(cue: &SubtitleCue) -> bool {
    let visible = visible_text_for_song_detection(cue.text());
    let trimmed = visible.trim();

    if trimmed.is_empty() {
        return false;
    }

    // Lines without any letters (pure symbols/digits) are not lyrics.
    if !trimmed.chars().any(char::is_alphabetic) {
        return false;
    }

    // Sentence punctuation strongly suggests dialogue rather than lyrics.
    if trimmed.chars().any(|c| ".!?;:\"".contains(c)) {
        return false;
    }

    let word_count = trimmed.split_whitespace().count();
    let non_space_chars = trimmed.chars().filter(|c| !c.is_whitespace()).count();

    word_count > 0 && word_count <= 8 && non_space_chars <= 64
}

/// Strip ASS markup from `text` for song detection: `{...}` override blocks
/// are dropped, and backslash escapes become spaces (`\N`/`\n`/`\h` turn into
/// a single space; any other escaped character is kept after a space).
fn visible_text_for_song_detection(text: &str) -> String {
    let mut visible = String::with_capacity(text.len());
    let mut chars = text.chars();
    let mut skipping_override = false;

    while let Some(current) = chars.next() {
        if skipping_override {
            // Consume everything up to and including the closing brace.
            skipping_override = current != '}';
            continue;
        }

        if current == '{' {
            skipping_override = true;
        } else if current == '\\' {
            // Escapes act as word separators; keep the payload of unknown ones.
            visible.push(' ');
            if let Some(escaped) = chars.next() {
                if !matches!(escaped, 'N' | 'n' | 'h') {
                    visible.push(escaped);
                }
            }
        } else {
            visible.push(current);
        }
    }

    visible
}

/// True when `next` starts within 250 centiseconds (2.5 s) of `previous`'s
/// end. Unparseable timestamps are treated as contiguous so a run is not
/// split by malformed timing data.
fn cues_are_contiguous(previous: &SubtitleCue, next: &SubtitleCue) -> bool {
    let previous_end = parse_ass_timestamp(previous.end());
    let next_start = parse_ass_timestamp(next.start());

    if let (Some(end), Some(start)) = (previous_end, next_start) {
        start.saturating_sub(end) <= 250
    } else {
        true
    }
}

/// Parse an ASS timestamp of the shape `H:MM:SS.CC` into centiseconds.
///
/// Returns `None` when the field count is not exactly three, the seconds
/// field lacks a `.` fraction, or any component fails to parse.
fn parse_ass_timestamp(value: &str) -> Option<u64> {
    let fields: Vec<&str> = value.trim().split(':').collect();
    let [hours, minutes, rest] = fields.as_slice() else {
        return None;
    };

    let hours: u64 = hours.parse().ok()?;
    let minutes: u64 = minutes.parse().ok()?;
    let (seconds, centis) = rest.split_once('.')?;
    let seconds: u64 = seconds.parse().ok()?;
    let centis: u64 = centis.parse().ok()?;

    Some(((hours * 60 + minutes) * 60 + seconds) * 100 + centis)
}

#[cfg(test)]
mod tests {
    use crate::domain::{AssClassificationPolicy, CueDisposition, CueKind};
    use crate::formats::ass;

    use super::classify_document;

    // Explicit signals under the default policy: a `\k` override marks a cue
    // as karaoke (preserved), and the `Lyrics` style marks one as a song
    // (preserved); plain dialogue stays translatable.
    #[test]
    fn classifies_karaoke_and_song_metadata_with_policy() {
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Default,,0,0,0,,Hello there!\nDialogue: 0,0:00:03.20,0:00:04.20,Default,,0,0,0,,{\\k20}Ka{\\k20}ra\nDialogue: 0,0:00:05.00,0:00:06.00,Lyrics,,0,0,0,,Shining star\n";
        let document = ass::parse(source).expect("parse should succeed");

        let (classified, report) =
            classify_document(&document, &AssClassificationPolicy::default());

        assert_eq!(classified.cues()[0].kind(), CueKind::Dialogue);
        assert_eq!(
            classified.cues()[0].disposition(),
            CueDisposition::Translate
        );
        assert_eq!(classified.cues()[1].kind(), CueKind::Karaoke);
        assert_eq!(classified.cues()[1].disposition(), CueDisposition::Preserve);
        assert_eq!(classified.cues()[2].kind(), CueKind::Song);
        assert_eq!(classified.cues()[2].disposition(), CueDisposition::Preserve);
        assert_eq!(report.summary.karaoke_cues, 1);
        assert_eq!(report.summary.song_cues, 1);
    }

    // Inferred detection: four short contiguous cues form a song run
    // (flagged for review under the default policy), while the later cue —
    // separated by a gap and ending with sentence punctuation — stays
    // dialogue.
    #[test]
    fn detects_song_runs_without_explicit_karaoke_tags() {
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Q0,,0,0,0,,Muteki no egao\nDialogue: 0,0:00:03.10,0:00:05.00,Q0,,0,0,0,,Shiritai sono himitsu\nDialogue: 0,0:00:05.10,0:00:07.00,Q0,,0,0,0,,Nuketeru toko sae\nDialogue: 0,0:00:07.10,0:00:09.00,Q0,,0,0,0,,Kanpeki de usotsuki na kimi\nDialogue: 0,0:00:12.00,0:00:14.00,Default,,0,0,0,,What happened here?\n";
        let document = ass::parse(source).expect("parse should succeed");

        let (classified, report) =
            classify_document(&document, &AssClassificationPolicy::default());

        assert_eq!(classified.cues()[0].kind(), CueKind::Song);
        assert_eq!(classified.cues()[0].disposition(), CueDisposition::Review);
        assert_eq!(classified.cues()[1].kind(), CueKind::Song);
        assert_eq!(classified.cues()[2].kind(), CueKind::Song);
        assert_eq!(classified.cues()[3].kind(), CueKind::Song);
        assert_eq!(classified.cues()[4].kind(), CueKind::Dialogue);
        assert_eq!(report.summary.review_cues, 4);
    }

    // Policy override: setting `inferred_song_policy` to Translate makes the
    // inferred song run translatable instead of review-flagged.
    #[test]
    fn can_override_inferred_song_policy_to_translate() {
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Q0,,0,0,0,,Muteki no egao\nDialogue: 0,0:00:03.10,0:00:05.00,Q0,,0,0,0,,Shiritai sono himitsu\nDialogue: 0,0:00:05.10,0:00:07.00,Q0,,0,0,0,,Nuketeru toko sae\nDialogue: 0,0:00:07.10,0:00:09.00,Q0,,0,0,0,,Kanpeki de usotsuki na kimi\n";
        let document = ass::parse(source).expect("parse should succeed");
        let policy = AssClassificationPolicy {
            inferred_song_policy: CueDisposition::Translate,
            ..AssClassificationPolicy::default()
        };

        let (classified, report) = classify_document(&document, &policy);

        assert!(
            classified
                .cues()
                .iter()
                .all(|cue| cue.disposition() == CueDisposition::Translate)
        );
        assert_eq!(report.summary.translatable_cues, 4);
        assert_eq!(report.summary.review_cues, 0);
    }
}