shinkai-translator 0.1.3

CLI tool for translating video subtitles with LLMs through OpenAI-compatible APIs, with native PGS OCR
use std::collections::{HashMap, HashSet};

use futures::stream::{self, StreamExt, TryStreamExt};

use crate::domain::{SubtitleDocument, TranslationOptions, TranslationResult};
use crate::error::TranslatorError;
use crate::formats;
use crate::providers::{TranslatedItem, TranslationBatch, TranslationProvider};

use super::chunker;

/// Translates the translatable cues of `document` through `provider` and renders the result.
///
/// The pipeline runs in this order:
/// 1. `options` is validated.
/// 2. The document is classified with `formats::classify_document`, using
///    `options.ass_classification_policy`.
/// 3. The classified cues are split into batches by `chunker::chunk_cues`, which also
///    receives the provider's capabilities.
/// 4. The batches are sent to the provider concurrently, with
///    `options.max_parallel_batches` passed as the concurrency limit, and each response
///    is checked by `validate_batch_response`.
/// 5. The translated texts are applied to the classified document by cue ID and the
///    document is rendered.
///
/// When the classified document has no cues, or chunking produces no batches, no batch is
/// sent to the provider: the classified document is rendered unchanged and the result
/// reports a batch count of `0`.
///
/// # Errors
///
/// Returns the error produced by option validation, by the provider, by response
/// validation, by applying the translations, or by rendering. The first failing batch
/// aborts the whole translation.
pub async fn translate_document(
    provider: &dyn TranslationProvider,
    document: &SubtitleDocument,
    options: &TranslationOptions,
) -> Result<TranslationResult, TranslatorError> {
    options.validate()?;

    let (classified_document, classification_report) =
        formats::classify_document(document, &options.ass_classification_policy);

    if classified_document.cues().is_empty() {
        let rendered = formats::render_subtitle(&classified_document)?;
        return Ok(TranslationResult::new(
            classified_document,
            rendered,
            0,
            classification_report,
        ));
    }

    let capabilities = provider.capabilities();
    let batches = chunker::chunk_cues(classified_document.cues(), options, &capabilities);

    if batches.is_empty() {
        let rendered = formats::render_subtitle(&classified_document)?;
        return Ok(TranslationResult::new(
            classified_document,
            rendered,
            0,
            classification_report,
        ));
    }

    let responses = stream::iter(batches.iter().map(|batch| {
        // The request owns its data, so each batch is copied exactly once; the borrowed
        // original is kept around only to validate the provider's answer against it.
        let request = TranslationBatch {
            source_language: options.source_language.clone(),
            target_language: options.target_language.clone(),
            system_prompt: options.system_prompt.clone(),
            items: batch.clone(),
        };

        async move {
            let response = provider.translate_batch(request).await?;
            validate_batch_response(batch, &response.items)?;

            Ok::<Vec<TranslatedItem>, TranslatorError>(response.items)
        }
    }))
    .buffer_unordered(options.max_parallel_batches)
    .try_collect::<Vec<_>>()
    .await?;

    // Responses arrive in completion order, not request order; keying by cue ID makes
    // the arrival order irrelevant.
    let translations: HashMap<_, _> = responses
        .into_iter()
        .flatten()
        .map(|item| (item.id, item.text))
        .collect();

    let translated_document = classified_document.translated_with(&translations)?;
    let rendered = formats::render_subtitle(&translated_document)?;

    Ok(TranslationResult::new(
        translated_document,
        rendered,
        batches.len(),
        classification_report,
    ))
}

/// Checks that a provider response answers exactly the batch that was requested.
///
/// The checks run in this order, and the first failure is returned as a
/// `TranslatorError::Validation`:
/// 1. the response has as many items as the request;
/// 2. the set of cue IDs in the response equals the set of cue IDs in the request;
/// 3. no translated text is empty or whitespace-only.
fn validate_batch_response(
    request_items: &[crate::providers::TranslationBatchItem],
    response_items: &[TranslatedItem],
) -> Result<(), TranslatorError> {
    let requested_count = request_items.len();
    let returned_count = response_items.len();
    if returned_count != requested_count {
        return Err(TranslatorError::Validation(format!(
            "provider returned {} items for a batch of {} cues",
            returned_count, requested_count
        )));
    }

    // Sets are compared so that the provider is free to reorder items inside a batch.
    let requested_ids: HashSet<&str> = request_items
        .iter()
        .map(|item| item.id.as_str())
        .collect();
    let returned_ids: HashSet<&str> = response_items
        .iter()
        .map(|item| item.id.as_str())
        .collect();
    if returned_ids != requested_ids {
        return Err(TranslatorError::Validation(
            "provider response cue IDs do not match the request".to_owned(),
        ));
    }

    let every_text_present = response_items
        .iter()
        .all(|item| !item.text.trim().is_empty());
    if !every_text_present {
        return Err(TranslatorError::Validation(
            "provider returned an empty translated text".to_owned(),
        ));
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use tokio::time::{Duration, sleep};

    use crate::domain::{
        CueKind, SubtitleCue, SubtitleDocument, SubtitleFormat, TranslationOptions,
    };
    use crate::formats::ass;
    use crate::providers::{
        TranslatedItem, TranslationBatch, TranslationBatchOutput, TranslationProvider,
    };

    use super::translate_document;

    /// Builds a provider response that keeps every cue ID of `batch` and prepends
    /// `prefix` to each cue text. Shared by all fake providers below.
    fn prefixed_output(batch: TranslationBatch, prefix: &str) -> TranslationBatchOutput {
        TranslationBatchOutput {
            items: batch
                .items
                .into_iter()
                .map(|item| TranslatedItem {
                    id: item.id,
                    text: format!("{prefix}{}", item.text),
                })
                .collect(),
        }
    }

    /// Provider that answers immediately with `translated:<source text>`.
    struct EchoProvider;

    #[async_trait]
    impl TranslationProvider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            Ok(prefixed_output(batch, "translated:"))
        }
    }

    /// A single SRT cue is translated and shows up in both the document and the
    /// rendered output.
    #[tokio::test]
    async fn translates_all_cues() {
        let document = SubtitleDocument::from_parts(
            SubtitleFormat::Srt,
            vec![SubtitleCue::new(
                "cue-1",
                Some("1".to_owned()),
                "00:00:00,000",
                "00:00:01,000",
                None,
                "hello",
                BTreeMap::new(),
            )],
            crate::domain::RenderPlan::Srt,
        );
        let options = TranslationOptions::default();

        let result = translate_document(&EchoProvider, &document, &options)
            .await
            .expect("translation should succeed");

        assert_eq!(result.document().cues()[0].text(), "translated:hello");
        assert!(result.rendered().contains("translated:hello"));
    }

    /// Provider that records every source text it is asked to translate, then answers
    /// like `EchoProvider`.
    struct RecordingProvider {
        seen: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl TranslationProvider for RecordingProvider {
        fn name(&self) -> &str {
            "recording"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            // The guard is a temporary of this statement, so the lock is released
            // before the response is built.
            self.seen
                .lock()
                .expect("seen mutex should lock")
                .extend(batch.items.iter().map(|item| item.text.clone()));

            Ok(prefixed_output(batch, "translated:"))
        }
    }

    /// With default options, only the plain dialogue line of an ASS script reaches the
    /// provider; the karaoke and song lines keep their original text.
    #[tokio::test]
    async fn preserves_non_translatable_ass_cues() {
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Default,,0,0,0,,Hello there\nDialogue: 0,0:00:03.10,0:00:04.00,Default,,0,0,0,,{\\k20}Ka{\\k20}ra\nDialogue: 0,0:00:04.10,0:00:05.00,Lyrics,,0,0,0,,Shining star\n";
        let document = ass::parse(source).expect("parse should succeed");
        let seen = Arc::new(Mutex::new(Vec::new()));
        let provider = RecordingProvider { seen: seen.clone() };

        let result = translate_document(&provider, &document, &TranslationOptions::default())
            .await
            .expect("translation should succeed");

        assert_eq!(result.document().cues()[0].kind(), CueKind::Dialogue);
        assert_eq!(result.document().cues()[0].text(), "translated:Hello there");
        assert_eq!(result.document().cues()[1].kind(), CueKind::Karaoke);
        assert_eq!(result.document().cues()[1].text(), "{\\k20}Ka{\\k20}ra");
        assert_eq!(result.document().cues()[2].kind(), CueKind::Song);
        assert_eq!(result.document().cues()[2].text(), "Shining star");
        assert_eq!(
            seen.lock().expect("seen mutex should lock").as_slice(),
            &["Hello there"]
        );
        assert_eq!(result.batches(), 1);
    }

    /// Provider that takes 40 ms per batch and tracks how many calls overlap.
    struct SlowProvider {
        /// Number of `translate_batch` calls currently in flight.
        current: Arc<AtomicUsize>,
        /// Highest value `current` has reached so far.
        max_seen: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl TranslationProvider for SlowProvider {
        fn name(&self) -> &str {
            "slow"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
            // Atomic running maximum; replaces a hand-written compare-exchange loop.
            self.max_seen.fetch_max(in_flight, Ordering::SeqCst);

            sleep(Duration::from_millis(40)).await;
            self.current.fetch_sub(1, Ordering::SeqCst);

            Ok(prefixed_output(batch, "ok:"))
        }
    }

    /// Six single-cue batches with a parallelism limit of four must overlap, but never
    /// beyond the configured limit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn translates_batches_in_parallel() {
        const PARALLELISM: usize = 4;

        let cues = (0..6)
            .map(|index| {
                SubtitleCue::new(
                    format!("cue-{index}"),
                    None,
                    format!("00:00:0{index},000"),
                    format!("00:00:0{},500", index + 1),
                    None,
                    format!("line-{index}"),
                    BTreeMap::new(),
                )
            })
            .collect::<Vec<_>>();
        let document =
            SubtitleDocument::from_parts(SubtitleFormat::Srt, cues, crate::domain::RenderPlan::Srt);
        let current = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let provider = SlowProvider {
            current: current.clone(),
            max_seen: max_seen.clone(),
        };
        let options = TranslationOptions {
            max_batch_items: 1,
            max_parallel_batches: PARALLELISM,
            ..TranslationOptions::default()
        };

        let result = translate_document(&provider, &document, &options)
            .await
            .expect("translation should succeed");

        let peak = max_seen.load(Ordering::SeqCst);
        assert_eq!(result.batches(), 6);
        assert!(peak >= 2);
        assert!(peak <= PARALLELISM);
        assert!(result.rendered().contains("ok:line-0"));
    }
}