shinkai-translator 0.1.3

CLI tool for translating video subtitles with LLMs through OpenAI-compatible APIs, with native PGS OCR
use std::collections::BTreeMap;

use crate::domain::{
    AssCueRenderTemplate, AssRenderBlock, RenderPlan, SubtitleCue, SubtitleDocument, SubtitleFormat,
};
use crate::error::TranslatorError;

use super::{normalize_newlines, push_terminal_newline};

/// Parses ASS subtitle text into a [`SubtitleDocument`].
///
/// Every input line becomes one render block: `Dialogue:` lines found inside
/// the `[Events]` section become cue blocks, and every other line (section
/// headers, `Format:` lines, anything outside `[Events]`) is stored verbatim
/// as a raw block.
///
/// # Errors
///
/// Returns [`TranslatorError::Parse`] when a `Dialogue:` line appears before
/// any `Format:` line, and propagates errors from parsing the `Format:` line
/// or an individual dialogue line.
pub fn parse(source: &str) -> Result<SubtitleDocument, TranslatorError> {
    let text = normalize_newlines(source);
    let mut blocks = Vec::new();
    let mut cues = Vec::new();
    let mut inside_events = false;
    let mut event_format: Option<AssFormatDefinition> = None;

    for raw_line in text.lines() {
        let stripped = raw_line.trim();
        let is_section_header = stripped.starts_with('[') && stripped.ends_with(']');

        if is_section_header {
            // Any bracketed header switches sections; only `[Events]` carries cues.
            inside_events = stripped.eq_ignore_ascii_case("[Events]");
        } else if inside_events {
            if let Some(format_body) = raw_line.trim_start().strip_prefix("Format:") {
                // The Format line itself is still emitted verbatim below.
                event_format = Some(AssFormatDefinition::parse(format_body)?);
            } else if let Some((prefix, fields)) = split_event_prefix(raw_line, "Dialogue:") {
                let Some(definition) = event_format.as_ref() else {
                    return Err(TranslatorError::Parse(
                        "ASS Dialogue line appeared before a Format line".to_owned(),
                    ));
                };

                // The next cue's index is the number of cues collected so far.
                let (cue, template) = parse_dialogue_line(prefix, fields, definition, cues.len())?;
                cues.push(cue);
                blocks.push(AssRenderBlock::Cue(template));
                continue;
            }
        }

        blocks.push(AssRenderBlock::Raw(raw_line.to_owned()));
    }

    Ok(SubtitleDocument::from_parts(
        SubtitleFormat::Ass,
        cues,
        RenderPlan::Ass { lines: blocks },
    ))
}

/// Renders an ASS document back to text.
///
/// Raw blocks are emitted unchanged; cue blocks are rebuilt from their stored
/// template and the cue's current text. Lines are joined with `\n` and the
/// result is passed through `push_terminal_newline`.
///
/// # Errors
///
/// Returns [`TranslatorError::UnsupportedFormat`] when the document's render
/// plan is not the ASS variant.
///
/// # Panics
///
/// Panics if a cue block's `cue_index` is out of range for the document's cues.
pub fn render(document: &SubtitleDocument) -> Result<String, TranslatorError> {
    let RenderPlan::Ass { lines } = &document.render_plan else {
        return Err(TranslatorError::UnsupportedFormat(
            "document is not ASS".to_owned(),
        ));
    };

    let cues = document.cues();
    let mut output_lines = Vec::with_capacity(lines.len());

    for block in lines {
        let text = match block {
            AssRenderBlock::Cue(template) => {
                render_dialogue_line(&cues[template.cue_index], template)
            }
            AssRenderBlock::Raw(raw) => raw.clone(),
        };
        output_lines.push(text);
    }

    Ok(push_terminal_newline(output_lines.join("\n")))
}

/// Splits one `Dialogue:` line into a cue and the template needed to re-render it.
///
/// `dialogue_prefix` is the part of the line up to and including `Dialogue:`
/// (with any leading whitespace), and `body` is everything after it. The body
/// is split according to `format_definition`. Start and End are stored
/// trimmed, Text is stored exactly as written, and every other column goes
/// into the cue's attribute map (trimmed) under its Format field name.
///
/// The returned template keeps the untrimmed columns around the text so the
/// original spacing is preserved when the line is rendered again.
///
/// # Errors
///
/// Propagates the error from `split_ass_fields_raw` when the body has fewer
/// columns than the Format line declares.
fn parse_dialogue_line(
    dialogue_prefix: &str,
    body: &str,
    format_definition: &AssFormatDefinition,
    cue_index: usize,
) -> Result<(SubtitleCue, AssCueRenderTemplate), TranslatorError> {
    let text_at = format_definition.text_index;
    let columns = split_ass_fields_raw(body, format_definition.fields.len())?;

    let mut start = String::new();
    let mut end = String::new();
    let mut text = String::new();
    let mut attributes = BTreeMap::new();

    let named_columns = format_definition.fields.iter().zip(&columns);
    for (position, (name, raw)) in named_columns.enumerate() {
        match position {
            p if p == format_definition.start_index => start = raw.trim().to_owned(),
            p if p == format_definition.end_index => end = raw.trim().to_owned(),
            // Text keeps its exact whitespace; it is the payload being translated.
            p if p == text_at => text = raw.clone(),
            _ => {
                attributes.insert(name.clone(), raw.trim().to_owned());
            }
        }
    }

    // Everything before the text column, each column followed by its comma.
    let mut prefix = String::from(dialogue_prefix);
    for column in &columns[..text_at] {
        prefix.push_str(column);
        prefix.push(',');
    }

    // Everything after the text column, each column preceded by its comma.
    let mut suffix = String::new();
    for column in columns.iter().skip(text_at + 1) {
        suffix.push(',');
        suffix.push_str(column);
    }

    let cue_id = format!("cue-{}", cue_index + 1);
    let cue = SubtitleCue::new(cue_id, None, start, end, None, text, attributes);
    let template = AssCueRenderTemplate {
        cue_index,
        prefix,
        suffix,
    };

    Ok((cue, template))
}

/// Splits `line` right after `prefix`, ignoring leading whitespace when matching.
///
/// Returns `Some((head, body))` where `head` is the leading whitespace plus
/// `prefix` and `body` is the remainder of the line, or `None` when the line
/// (after leading whitespace) does not start with `prefix`.
fn split_event_prefix<'a>(line: &'a str, prefix: &str) -> Option<(&'a str, &'a str)> {
    let body = line.trim_start().strip_prefix(prefix)?;
    // `body` is a suffix of `line`, so the head is whatever precedes it.
    let head_len = line.len() - body.len();
    Some(line.split_at(head_len))
}

/// Splits `source` on commas into exactly `expected` untrimmed columns.
///
/// Only the first `expected - 1` commas act as separators, so the final
/// column keeps any further commas it contains. An `expected` of zero yields
/// an empty vector.
///
/// # Errors
///
/// Returns [`TranslatorError::Parse`] when `source` contains too few commas
/// to produce `expected` columns.
fn split_ass_fields_raw(source: &str, expected: usize) -> Result<Vec<String>, TranslatorError> {
    if expected == 0 {
        return Ok(Vec::new());
    }

    // `splitn` stops splitting after `expected` pieces, leaving the rest intact.
    let values: Vec<String> = source.splitn(expected, ',').map(str::to_owned).collect();

    if values.len() < expected {
        return Err(TranslatorError::Parse(format!(
            "ASS Dialogue line does not match Format field count: {source}"
        )));
    }

    Ok(values)
}

fn render_dialogue_line(cue: &SubtitleCue, template: &AssCueRenderTemplate) -> String {
    format!("{}{}{}", template.prefix, cue.text(), template.suffix)
}

/// Column layout declared by a `Format:` line in the `[Events]` section.
struct AssFormatDefinition {
    /// Field names in declaration order, with surrounding whitespace trimmed.
    fields: Vec<String>,
    /// Position of the `Start` field within `fields`.
    start_index: usize,
    /// Position of the `End` field within `fields`.
    end_index: usize,
    /// Position of the `Text` field within `fields`; `parse` requires it to be last.
    text_index: usize,
}

impl AssFormatDefinition {
    /// Builds a definition from the text following `Format:`.
    ///
    /// `body` is split on commas and each field name is trimmed. The `Start`,
    /// `End` and `Text` fields are located case-insensitively.
    ///
    /// # Errors
    ///
    /// Returns [`TranslatorError::Parse`] when any of `Start`, `End` or `Text`
    /// is missing, or when `Text` is not the final field.
    fn parse(body: &str) -> Result<Self, TranslatorError> {
        let fields: Vec<String> = body
            .split(',')
            .map(str::trim)
            .map(str::to_owned)
            .collect();

        let locate = |name: &str| find_field_index(&fields, name);
        let start_index = locate("Start")?;
        let end_index = locate("End")?;
        let text_index = locate("Text")?;

        // Text must come last so that commas inside it are never treated as separators.
        let text_is_last = text_index + 1 == fields.len();
        if !text_is_last {
            return Err(TranslatorError::Parse(
                "ASS Format Text field must be the last field".to_owned(),
            ));
        }

        Ok(Self {
            fields,
            start_index,
            end_index,
            text_index,
        })
    }
}

/// Returns the position of the first field equal to `name`, ignoring ASCII case.
///
/// # Errors
///
/// Returns [`TranslatorError::Parse`] naming the missing field when no entry
/// in `fields` matches.
fn find_field_index(fields: &[String], name: &str) -> Result<usize, TranslatorError> {
    for (index, field) in fields.iter().enumerate() {
        if field.eq_ignore_ascii_case(name) {
            return Ok(index);
        }
    }

    Err(TranslatorError::Parse(format!(
        "ASS Format line is missing the {name} field"
    )))
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{parse, render};

    /// Parsing a minimal script yields one cue, and rendering the untouched
    /// document reproduces the input byte-for-byte.
    #[test]
    fn parses_dialogue_lines_and_roundtrips() {
        // The text column contains a comma, which must stay part of the text.
        let source = "[Script Info]\nTitle: Sample\n\n[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Default,,0,0,0,,Hello, world\n";
        let document = parse(source).expect("parse should succeed");

        assert_eq!(document.cue_count(), 1);
        assert_eq!(document.cues()[0].text(), "Hello, world");

        let rendered = render(&document).expect("render should succeed");
        assert_eq!(rendered, source);
    }

    /// Replacing a cue's text changes only the text column; the indentation
    /// and the irregular spacing around the other columns are kept as written.
    #[test]
    fn preserves_original_dialogue_structure_when_text_changes() {
        // The Dialogue line is indented and has uneven spacing around its columns.
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n  Dialogue: 0, 0:00:01.00, 0:00:03.00, Default , Speaker, 10,20,30,,Hello, world\n";
        let document = parse(source).expect("parse should succeed");
        // Cue ids are "cue-N" with N starting at 1.
        let translated = document
            .translated_with(&HashMap::from([(
                String::from("cue-1"),
                String::from("Olá, mundo"),
            )]))
            .expect("translation replacement should succeed");

        let rendered = render(&translated).expect("render should succeed");
        assert_eq!(
            rendered,
            "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n  Dialogue: 0, 0:00:01.00, 0:00:03.00, Default , Speaker, 10,20,30,,Olá, mundo\n"
        );
    }
}