shinkai-translator 0.1.3

CLI tool for translating video subtitles with LLMs through OpenAI-compatible APIs, with native PGS OCR
use std::collections::BTreeMap;

use crate::domain::{RenderBlock, RenderPlan, SubtitleCue, SubtitleDocument, SubtitleFormat};
use crate::error::TranslatorError;

use super::{normalize_newlines, parse_arrow_timing_line, push_terminal_newline, split_blocks};

/// Parses WebVTT source text into a `SubtitleDocument`.
///
/// The input is normalized and split into blank-line-separated blocks; the
/// first block must begin with the `WEBVTT` magic. Blocks recognised as cues
/// are collected into the cue list, while everything else (header, NOTE /
/// STYLE / REGION sections) is kept verbatim as raw render blocks so the
/// original file layout can be reproduced on render.
///
/// # Errors
///
/// Returns `TranslatorError::Parse` when the input is empty, does not start
/// with `WEBVTT`, or contains a malformed timing line.
pub fn parse(source: &str) -> Result<SubtitleDocument, TranslatorError> {
    let normalized = normalize_newlines(source);
    let blocks = split_blocks(&normalized);

    let Some(first) = blocks.first() else {
        return Err(TranslatorError::Parse("VTT input is empty".to_owned()));
    };
    if !first.trim_start().starts_with("WEBVTT") {
        return Err(TranslatorError::Parse(
            "VTT files must start with WEBVTT".to_owned(),
        ));
    }

    let mut cues = Vec::new();
    let mut render_blocks = Vec::new();

    for block in blocks {
        // Cue numbers are 1-based and assigned in document order.
        match parse_cue_block(&block, cues.len() + 1)? {
            Some(cue) => {
                cues.push(cue);
                render_blocks.push(RenderBlock::Cue(cues.len() - 1));
            }
            None => render_blocks.push(RenderBlock::Raw(block)),
        }
    }

    Ok(SubtitleDocument::from_parts(
        SubtitleFormat::Vtt,
        cues,
        RenderPlan::Vtt {
            blocks: render_blocks,
        },
    ))
}

/// Attempts to parse a single block as a VTT cue.
///
/// Returns `Ok(None)` when the block is not a cue (header, NOTE / STYLE /
/// REGION sections, or anything without a `-->` timing line) so the caller
/// can preserve it verbatim. A cue block is either a timing line followed by
/// text, or an identifier line, then the timing line, then text.
///
/// # Errors
///
/// Propagates `TranslatorError` from `parse_arrow_timing_line` when the
/// timing line is malformed.
fn parse_cue_block(block: &str, cue_number: usize) -> Result<Option<SubtitleCue>, TranslatorError> {
    let lines: Vec<&str> = block.lines().collect();
    let Some(&first_line) = lines.first() else {
        return Ok(None);
    };

    // Non-cue section markers; hoisted so trim_start() runs once instead of
    // once per prefix check.
    const NON_CUE_PREFIXES: [&str; 4] = ["WEBVTT", "NOTE", "STYLE", "REGION"];
    let head = first_line.trim_start();
    if NON_CUE_PREFIXES
        .iter()
        .any(|prefix| head.starts_with(prefix))
    {
        return Ok(None);
    }

    let (identifier, timing_index) = if first_line.contains("-->") {
        (None, 0)
    } else if lines.get(1).is_some_and(|line| line.contains("-->")) {
        (Some(first_line.trim().to_owned()), 1)
    } else {
        return Ok(None);
    };

    let (start, end, settings) = parse_arrow_timing_line(lines[timing_index])?;
    // `lines[timing_index]` exists, so slicing at timing_index + 1 is always
    // in bounds; joining an empty slice yields "", so no length check needed.
    let text = lines[(timing_index + 1)..].join("\n");

    Ok(Some(SubtitleCue::new(
        format!("cue-{cue_number}"),
        identifier,
        start,
        end,
        settings,
        text,
        BTreeMap::new(),
    )))
}

/// Serializes a VTT document back to text by replaying its render plan.
///
/// Raw blocks are emitted verbatim; cue blocks are re-rendered from the
/// (possibly translated) cue data. Blocks are joined with a blank line and
/// the output always ends with a trailing newline.
///
/// # Errors
///
/// Returns `TranslatorError::UnsupportedFormat` when the document's render
/// plan is not VTT.
pub fn render(document: &SubtitleDocument) -> Result<String, TranslatorError> {
    let RenderPlan::Vtt { blocks } = &document.render_plan else {
        return Err(TranslatorError::UnsupportedFormat(
            "document is not VTT".to_owned(),
        ));
    };

    let mut parts = Vec::with_capacity(blocks.len());
    for block in blocks {
        let part = match block {
            RenderBlock::Raw(raw) => raw.clone(),
            RenderBlock::Cue(index) => render_cue(&document.cues()[*index]),
        };
        parts.push(part);
    }

    Ok(push_terminal_newline(parts.join("\n\n")))
}

/// Formats a single cue as its VTT text: an optional identifier line, the
/// timing line (with cue settings appended when present), then the cue text.
/// Empty text produces no trailing line.
fn render_cue(cue: &SubtitleCue) -> String {
    let mut out = String::new();

    if let Some(identifier) = cue.identifier() {
        out.push_str(identifier);
        out.push('\n');
    }

    let timing = match cue.settings() {
        Some(settings) => format!("{} --> {} {settings}", cue.start(), cue.end()),
        None => format!("{} --> {}", cue.start(), cue.end()),
    };
    out.push_str(&timing);

    let text = cue.text();
    if !text.is_empty() {
        out.push('\n');
        out.push_str(text);
    }

    out
}

#[cfg(test)]
mod tests {
    use super::{parse, render};

    // Round-trip check: the WEBVTT header, cue identifier, cue settings, and
    // multi-line cue text must all survive parse -> render unchanged.
    #[test]
    fn preserves_webvtt_header_and_identifier() {
        let source = "WEBVTT\n\nintro\n00:00:01.000 --> 00:00:03.000 line:90%\nhello\nthere\n";
        let doc = parse(source).expect("parse should succeed");

        assert_eq!(doc.cue_count(), 1);
        assert_eq!(doc.cues()[0].identifier(), Some("intro"));
        assert_eq!(
            render(&doc).expect("render should succeed"),
            source,
            "render must reproduce the original source exactly"
        );
    }
}