Skip to main content

psyche_subtitle_toolkit/
pipeline.rs

1use std::collections::BTreeMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use tempfile::tempdir;
6use tokio::sync::Semaphore;
7
8use crate::error::{Result, SubtitleToolkitError};
9use crate::media::mkv::{
10    SubtitleFormat, discover_mkv_files, extract_subtitle, inspect_mkv, mux_subtitle_in_place,
11    select_subtitle_track,
12};
13use crate::retry::retry_async;
14use crate::subtitles::ass::AssSubtitle;
15use crate::subtitles::model::SubtitleDocument;
16use crate::subtitles::srt::SrtSubtitle;
17use crate::subtitles::vtt::VttSubtitle;
18use crate::subtitles::structured::{
19    apply_translation, chunk_document_by_lines, parse_numbered_text, reinject_tags, strip_tags,
20    to_numbered_text,
21};
22use crate::translation::{TranslationRequest, Translator};
23
/// Options for the [`translate_mkv`] pipeline.
///
/// A single value configures a whole run; the same options are applied to every
/// MKV file discovered under [`TranslateMkvOptions::input`].
#[derive(Debug, Clone)]
pub struct TranslateMkvOptions {
    /// Path to an MKV file or directory containing MKV files.
    pub input: PathBuf,
    /// Target language code (e.g. `"pt-BR"`, `"en"`, `"ja"`).
    pub target_language: String,
    /// Specific subtitle track ID to translate. If `None`, the track is chosen by
    /// `select_subtitle_track`.
    /// NOTE(review): this used to read "selects the first ASS track", but the pipeline
    /// also handles SRT and VTT tracks — confirm the actual selection order.
    pub track_id: Option<u64>,
    /// If `true`, copies the extracted and translated subtitle files
    /// (`source.<ext>` / `translated.<ext>`) into a directory next to the MKV whose
    /// name is the MKV path with the extension `psyche-subtitle-toolkit-temp`.
    pub keep_temp: bool,
    /// If `true`, prints a per-file summary of what would be translated and leaves
    /// the MKV files unmodified.
    pub dry_run: bool,
    /// If `true`, saves progress to `.psyche-subtitle-toolkit-progress.json` and skips
    /// already-translated files on restart.
    pub resume: bool,
    /// Maximum number of chunks to translate concurrently; `1` means sequential.
    /// NOTE(review): the documented default is 1, but no `Default` impl is visible
    /// in this file — confirm where the default is applied.
    pub max_concurrent: usize,
}
42
43/// Translate ASS subtitles in MKV file(s) and mux the result back in-place.
44///
45/// For each MKV file:
46/// 1. Inspects tracks and selects the ASS subtitle track
47/// 2. Extracts the subtitle to a temp directory
48/// 3. Strips ASS override tags (`{\pos(...)}`, `{\an7}`, etc.)
49/// 4. Chunks cues into 200-line batches
50/// 5. Translates each chunk via the provided [`Translator`]
51/// 6. Re-injects override tags
52/// 7. Muxes the translated subtitle into the MKV (replacing the original track)
53///
54/// If `input` is a directory, processes all `.mkv` files sequentially.
55pub async fn translate_mkv(
56    options: TranslateMkvOptions,
57    translator: Arc<dyn Translator>,
58) -> Result<()> {
59    let files = discover_mkv_files(&options.input).await?;
60
61    let progress_path = progress_file_path(&options.input);
62    let mut completed: Vec<String> = if options.resume && progress_path.exists() {
63        let data = tokio::fs::read_to_string(&progress_path).await?;
64        serde_json::from_str(&data).unwrap_or_default()
65    } else {
66        Vec::new()
67    };
68
69    let total = files.len();
70    for (i, file) in files.into_iter().enumerate() {
71        let file_str = file.to_string_lossy().to_string();
72
73        if options.resume && completed.contains(&file_str) {
74            eprintln!("[resume] skipping ({}/{}): {}", i + 1, total, file.display());
75            continue;
76        }
77
78        translate_one(file, &options, translator.clone()).await?;
79
80        if options.resume {
81            completed.push(file_str);
82            let json = serde_json::to_string_pretty(&completed)?;
83            tokio::fs::write(&progress_path, &json).await?;
84            eprintln!("[resume] progress saved ({}/{})", completed.len(), total);
85        }
86    }
87
88    if options.resume && progress_path.exists() {
89        tokio::fs::remove_file(&progress_path).await?;
90    }
91
92    Ok(())
93}
94
95async fn translate_one(
96    file: PathBuf,
97    options: &TranslateMkvOptions,
98    translator: Arc<dyn Translator>,
99) -> Result<()> {
100    let info = inspect_mkv(&file).await?;
101    let (track, format) = select_subtitle_track(&info, options.track_id)
102        .ok_or_else(|| SubtitleToolkitError::NoSubtitleTrack { path: file.clone() })?;
103
104    let ext = match format {
105        SubtitleFormat::Ass => "ass",
106        SubtitleFormat::Srt => "srt",
107        SubtitleFormat::Vtt => "vtt",
108    };
109
110    let temp_dir = tempdir()?;
111    let extracted_path = temp_dir.path().join(format!("source.{ext}"));
112    let translated_path = temp_dir.path().join(format!("translated.{ext}"));
113
114    let format_label = match format {
115        SubtitleFormat::Ass => "ASS",
116        SubtitleFormat::Srt => "SRT",
117        SubtitleFormat::Vtt => "VTT",
118    };
119    eprintln!("[translate] {} ({})", file.display(), format_label);
120    extract_subtitle(&file, track.id, &extracted_path).await?;
121
122    let source = tokio::fs::read_to_string(&extracted_path).await?;
123
124    let rendered = match format {
125        SubtitleFormat::Ass => {
126            let ass = AssSubtitle::parse(&source)?;
127            if options.dry_run {
128                let summary = dry_run_summary(ass.document(), &options.target_language);
129                println!("[dry-run] {}: {}", file.display(), summary);
130                return Ok(());
131            }
132            translate_ass(
133                ass,
134                &options.target_language,
135                options.max_concurrent,
136                translator,
137            )
138            .await?
139            .render()
140        }
141        SubtitleFormat::Srt => {
142            let srt = SrtSubtitle::parse(&source)?;
143            if options.dry_run {
144                let summary = dry_run_summary(srt.document(), &options.target_language);
145                println!("[dry-run] {}: {}", file.display(), summary);
146                return Ok(());
147            }
148            translate_srt(
149                srt,
150                &options.target_language,
151                options.max_concurrent,
152                translator,
153            )
154            .await?
155            .render()
156        }
157        SubtitleFormat::Vtt => {
158            let vtt = VttSubtitle::parse(&source)?;
159            if options.dry_run {
160                let summary = dry_run_summary(vtt.document(), &options.target_language);
161                println!("[dry-run] {}: {}", file.display(), summary);
162                return Ok(());
163            }
164            translate_vtt(
165                vtt,
166                &options.target_language,
167                options.max_concurrent,
168                translator,
169            )
170            .await?
171            .render()
172        }
173    };
174
175    eprintln!("[translate] muxing translated subtitle");
176    tokio::fs::write(&translated_path, &rendered).await?;
177    mux_subtitle_in_place(&file, track.id, &translated_path, &options.target_language).await?;
178    eprintln!("[translate] done: {}", file.display());
179
180    if options.keep_temp {
181        let persisted = file.with_extension("psyche-subtitle-toolkit-temp");
182        tokio::fs::create_dir_all(&persisted).await?;
183        tokio::fs::copy(&extracted_path, persisted.join(format!("source.{ext}"))).await?;
184        tokio::fs::copy(&translated_path, persisted.join(format!("translated.{ext}"))).await?;
185    }
186
187    Ok(())
188}
189
/// Compute where the resume progress file for `input` lives.
///
/// * If `input` is a directory (or a non-existent path without an extension, which is
///   assumed to be a directory), the progress file is placed inside it.
/// * If `input` is a single file — it exists as a regular file, or it at least looks
///   like one because it has an extension — the progress file is placed in its parent
///   directory. Without the "exists as a file" check, an extension-less file would
///   yield a path *below* a regular file, which can never be written.
///
/// This performs filesystem metadata lookups (`is_dir` / `is_file`) but never creates
/// or modifies anything.
fn progress_file_path(input: &std::path::Path) -> PathBuf {
    const PROGRESS_FILE_NAME: &str = ".psyche-subtitle-toolkit-progress.json";

    let points_at_file = !input.is_dir() && (input.is_file() || input.extension().is_some());
    let dir = if points_at_file {
        // A bare file name has an empty parent, which resolves relative to the CWD.
        input.parent().unwrap_or(input)
    } else {
        input
    };
    dir.join(PROGRESS_FILE_NAME)
}
200
201/// Generate a dry-run summary for a subtitle document: cue count, char count, chunk count.
202pub fn dry_run_summary(doc: &SubtitleDocument, target_language: &str) -> String {
203    let (clean_doc, _) = strip_tags(doc);
204    let chunks = chunk_document_by_lines(&clean_doc, 200);
205    let cue_count = clean_doc.cues.len();
206    let total_chars: usize = clean_doc.cues.iter().map(|c| c.text.len()).sum();
207    format!(
208        "{} cues, {} chars, {} chunk(s) → {}",
209        cue_count,
210        total_chars,
211        chunks.len(),
212        target_language,
213    )
214}
215
216/// Translate an ASS subtitle through the full processing pipeline.
217///
218/// 1. Strips ASS override tags (`{\pos(...)}`, `{\an7}`, etc.)
219/// 2. Chunks cues into 200-line batches
220/// 3. Translates each chunk via the provided [`Translator`]
221/// 4. Applies translated text back to the document
222/// 5. Re-injects the original override tags
223///
224/// Returns the translated [`AssSubtitle`]. Use [`AssSubtitle::render`] to get
225/// the final ASS string.
226pub async fn translate_ass(
227    mut ass: AssSubtitle,
228    target_language: &str,
229    max_concurrent: usize,
230    translator: Arc<dyn Translator>,
231) -> Result<AssSubtitle> {
232    let (mut clean_doc, tag_map) = strip_tags(ass.document());
233    clean_doc = translate_document(clean_doc, target_language, max_concurrent, translator).await?;
234    reinject_tags(&mut clean_doc, &tag_map);
235    *ass.document_mut() = clean_doc;
236    Ok(ass)
237}
238
239/// Translate an SRT subtitle through the processing pipeline.
240///
241/// 1. Chunks cues into 200-line batches
242/// 2. Translates each chunk via the provided [`Translator`]
243/// 3. Applies translated text back to the document
244///
245/// Returns the translated [`SrtSubtitle`]. Use [`SrtSubtitle::render`] to get
246/// the final SRT string.
247pub async fn translate_srt(
248    mut srt: SrtSubtitle,
249    target_language: &str,
250    max_concurrent: usize,
251    translator: Arc<dyn Translator>,
252) -> Result<SrtSubtitle> {
253    let clean_doc = translate_document(
254        srt.document().clone(),
255        target_language,
256        max_concurrent,
257        translator,
258    )
259    .await?;
260    *srt.document_mut() = clean_doc;
261    Ok(srt)
262}
263
264/// Translate a WebVTT subtitle through the processing pipeline.
265///
266/// 1. Chunks cues into 200-line batches
267/// 2. Translates each chunk via the provided [`Translator`]
268/// 3. Applies translated text back to the document
269///
270/// Returns the translated [`VttSubtitle`]. Use [`VttSubtitle::render`] to get
271/// the final WebVTT string.
272pub async fn translate_vtt(
273    mut vtt: VttSubtitle,
274    target_language: &str,
275    max_concurrent: usize,
276    translator: Arc<dyn Translator>,
277) -> Result<VttSubtitle> {
278    let clean_doc = translate_document(
279        vtt.document().clone(),
280        target_language,
281        max_concurrent,
282        translator,
283    )
284    .await?;
285    *vtt.document_mut() = clean_doc;
286    Ok(vtt)
287}
288
289/// Core translation logic shared by ASS, SRT, and VTT pipelines.
290async fn translate_document(
291    mut doc: SubtitleDocument,
292    target_language: &str,
293    max_concurrent: usize,
294    translator: Arc<dyn Translator>,
295) -> Result<SubtitleDocument> {
296    let chunks = chunk_document_by_lines(&doc, 200);
297    let chunk_count = chunks.len();
298    let cue_count = doc.cues.len();
299    let total_chars: usize = doc.cues.iter().map(|c| c.text.len()).sum();
300    eprintln!(
301        "[translate] {} cues, {} chars, {} chunk(s), {} concurrent",
302        cue_count, total_chars, chunk_count, max_concurrent,
303    );
304
305    let semaphore = Arc::new(Semaphore::new(max_concurrent));
306    let mut join_set = tokio::task::JoinSet::new();
307
308    for (i, chunk) in chunks.into_iter().enumerate() {
309        if chunk_count > 1 {
310            let chunk_chars: usize = chunk.cues.iter().map(|c| c.text.len()).sum();
311            eprintln!(
312                "[translate] chunk {}/{}: {} cues, {} chars",
313                i + 1,
314                chunk_count,
315                chunk.cues.len(),
316                chunk_chars,
317            );
318        }
319        let numbered = to_numbered_text(&chunk);
320        let ids: Vec<usize> = chunk.cues.iter().map(|cue| cue.id).collect();
321        let permit = semaphore
322            .clone()
323            .acquire_owned()
324            .await
325            .map_err(|e| SubtitleToolkitError::Translation {
326                provider: "pipeline",
327                message: format!("semaphore closed: {e}"),
328            })?;
329        let translator = translator.clone();
330        let target = target_language.to_string();
331
332        join_set.spawn(async move {
333            let _permit = permit;
334            let numbered_clone = numbered.clone();
335            let ids_clone = ids.clone();
336            let result = retry_async(3, || {
337                let numbered = numbered_clone.clone();
338                let ids = ids_clone.clone();
339                let translator = translator.clone();
340                let target = target.clone();
341                async move {
342                    let translated_text = translator
343                        .translate(TranslationRequest {
344                            source_text: &numbered,
345                            target_language: &target,
346                        })
347                        .await?;
348                    parse_numbered_text(&translated_text, &ids)
349                }
350            })
351            .await;
352            (i, result)
353        });
354    }
355
356    let mut all_translated = BTreeMap::new();
357    let mut results: Vec<(usize, Result<BTreeMap<usize, String>>)> = Vec::new();
358    while let Some(result) = join_set.join_next().await {
359        let (i, outcome) = result.map_err(|e| SubtitleToolkitError::Translation {
360            provider: "pipeline",
361            message: format!("task panicked: {e}"),
362        })?;
363        results.push((i, outcome));
364    }
365    results.sort_by_key(|(i, _)| *i);
366    for (_, result) in results {
367        all_translated.extend(result?);
368    }
369
370    apply_translation(&mut doc, all_translated);
371    Ok(doc)
372}
373
374#[cfg(test)]
375mod tests {
376    use super::*;
377    use crate::error::SubtitleToolkitError;
378    use crate::subtitles::ass::AssSubtitle;
379    use crate::translation::{TranslationRequest, Translator};
380    use std::sync::Mutex;
381
    /// A mock translator for integration tests.
    ///
    /// Records each source text it receives and then answers according to the first
    /// rule that applies:
    /// 1. if `error` is set, every call fails with that message;
    /// 2. otherwise, if the `sequential` queue is non-empty, its front entry is
    ///    removed and returned;
    /// 3. otherwise the `responses` map is consulted, falling back to an identity
    ///    translation (same text back) when there is no entry.
    struct FakeTranslator {
        /// Source texts received, in call order.
        received: Mutex<Vec<String>>,
        /// Maps source text → translated text. If missing, returns source unchanged.
        responses: std::collections::HashMap<String, String>,
        /// If set, all calls return this error.
        error: Option<String>,
        /// If set, returns responses in order (first call → first response, etc.).
        /// Used for testing retry on malformed output.
        sequential: Mutex<Vec<String>>,
    }
398
399    impl FakeTranslator {
400        fn new(responses: std::collections::HashMap<String, String>) -> Self {
401            Self {
402                received: Mutex::new(Vec::new()),
403                responses,
404                error: None,
405                sequential: Mutex::new(Vec::new()),
406            }
407        }
408
409        fn with_error(message: &str) -> Self {
410            Self {
411                received: Mutex::new(Vec::new()),
412                responses: std::collections::HashMap::new(),
413                error: Some(message.to_string()),
414                sequential: Mutex::new(Vec::new()),
415            }
416        }
417
418        fn with_sequential_responses(responses: Vec<String>) -> Self {
419            Self {
420                received: Mutex::new(Vec::new()),
421                responses: std::collections::HashMap::new(),
422                error: None,
423                sequential: Mutex::new(responses),
424            }
425        }
426
427        fn received_texts(&self) -> Vec<String> {
428            self.received.lock().unwrap().clone()
429        }
430    }
431
432    #[async_trait::async_trait]
433    impl Translator for FakeTranslator {
434        async fn translate(&self, request: TranslationRequest<'_>) -> crate::error::Result<String> {
435            self.received
436                .lock()
437                .unwrap()
438                .push(request.source_text.to_string());
439
440            if let Some(msg) = &self.error {
441                return Err(SubtitleToolkitError::Translation {
442                    provider: "fake",
443                    message: msg.clone(),
444                });
445            }
446
447            // Sequential mode: pop from front of queue
448            {
449                let mut seq = self.sequential.lock().unwrap();
450                if !seq.is_empty() {
451                    return Ok(seq.remove(0));
452                }
453            }
454
455            Ok(self
456                .responses
457                .get(request.source_text)
458                .cloned()
459                .unwrap_or_else(|| request.source_text.to_string()))
460        }
461    }
462
    /// Minimal ASS script with two tag-free dialogue cues:
    /// "Hello world" and "Goodbye world".
    const SIMPLE_ASS: &str = r"[Script Info]
Title: Test
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,Hello world
Dialogue: 0,0:00:03.00,0:00:04.00,Default,,0,0,0,,Goodbye world
";
476
    /// ASS script with two dialogue cues: the first starts with the override tags
    /// `{\pos(857.6,122.4)}{\an7}`, the second is plain text.
    const ASS_WITH_TAGS: &str = r"[Script Info]
Title: Test Tags
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,{\pos(857.6,122.4)}{\an7}Status line
Dialogue: 0,0:00:03.00,0:00:04.00,Default,,0,0,0,,Normal text
";
490
491    #[tokio::test]
492    async fn pipeline_translates_dialogue_and_preserves_structure() {
493        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
494
495        let mut responses = std::collections::HashMap::new();
496        responses.insert("<1> Hello world\n<2> Goodbye world".to_string(), "<1> Olá mundo\n<2> Adeus mundo".to_string());
497
498        let translator = Arc::new(FakeTranslator::new(responses));
499        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
500        let rendered = result.render();
501
502        assert!(rendered.contains("Olá mundo"));
503        assert!(rendered.contains("Adeus mundo"));
504        assert!(!rendered.contains("Hello world"));
505        assert!(!rendered.contains("Goodbye world"));
506
507        // Headers and styles preserved
508        assert!(rendered.contains("[Script Info]"));
509        assert!(rendered.contains("[V4+ Styles]"));
510        assert!(rendered.contains("[Events]"));
511    }
512
513    #[tokio::test]
514    async fn pipeline_passes_numbered_text_and_target_language_to_translator() {
515        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
516
517        let mut responses = std::collections::HashMap::new();
518        responses.insert("<1> Hello world\n<2> Goodbye world".to_string(), "<1> translated1\n<2> translated2".to_string());
519
520        let translator = Arc::new(FakeTranslator::new(responses));
521        translate_ass(ass, "ja", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
522
523        let texts = translator.received_texts();
524        assert_eq!(texts.len(), 1);
525        assert_eq!(texts[0], "<1> Hello world\n<2> Goodbye world");
526    }
527
528    #[tokio::test]
529    async fn pipeline_strips_and_reinjects_override_tags() {
530        let ass = AssSubtitle::parse(ASS_WITH_TAGS).unwrap();
531
532        // The translator receives clean text (no tags)
533        let mut responses = std::collections::HashMap::new();
534        responses.insert(
535            "<1> Status line\n<2> Normal text".to_string(),
536            "<1> Linha de status\n<2> Texto normal".to_string(),
537        );
538
539        let translator = Arc::new(FakeTranslator::new(responses));
540        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
541        let rendered = result.render();
542
543        // Tags reinjected
544        assert!(rendered.contains(r"{\pos(857.6,122.4)}{\an7}Linha de status"));
545        // Second line had no tags, stays clean
546        assert!(rendered.contains("Texto normal"));
547        assert!(!rendered.contains("Normal text"));
548
549        // Verify translator received clean text (no tags)
550        let texts = translator.received_texts();
551        assert!(!texts[0].contains(r"{\pos"));
552        assert!(!texts[0].contains(r"{\an7}"));
553    }
554
555    #[tokio::test]
556    async fn pipeline_chunks_large_documents() {
557        // Build an ASS with 300 cues to force multiple chunks at 200 lines/chunk.
558        let mut lines = vec![
559            "[Script Info]".to_string(),
560            "Title: Big".to_string(),
561            "ScriptType: v4.00+".to_string(),
562            "".to_string(),
563            "[V4+ Styles]".to_string(),
564            "Format: Name, Fontname, Fontsize".to_string(),
565            "Style: Default,Arial,20".to_string(),
566            "".to_string(),
567            "[Events]".to_string(),
568            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
569                .to_string(),
570        ];
571
572        for i in 1..=300 {
573            lines.push(format!(
574                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,subtitle line {i}",
575                i, i + 1,
576            ));
577        }
578
579        let ass_content = lines.join("\n");
580
581        // Verify chunking splits at 200 lines
582        let ass = AssSubtitle::parse(&ass_content).unwrap();
583        let (clean_doc, _) = crate::subtitles::structured::strip_tags(ass.document());
584        let chunks = crate::subtitles::structured::chunk_document_by_lines(&clean_doc, 200);
585        assert_eq!(chunks.len(), 2, "300 cues should produce 2 chunks at 200 lines");
586        assert_eq!(chunks[0].cues.len(), 200);
587        assert_eq!(chunks[1].cues.len(), 100);
588
589        // Build a FakeTranslator (identity translation)
590        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
591
592        let ass = AssSubtitle::parse(&ass_content).unwrap();
593        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
594        let rendered = result.render();
595
596        // All 300 cues should be present
597        for i in 1..=300 {
598            assert!(
599                rendered.contains(&format!("subtitle line {i}")),
600                "missing cue {i} in rendered output"
601            );
602        }
603
604        // Translator was called twice (2 chunks)
605        let texts = translator.received_texts();
606        assert_eq!(texts.len(), 2);
607
608        // All 300 cue IDs accounted for across all calls
609        let mut all_ids: Vec<usize> = Vec::new();
610        for text in &texts {
611            for line in text.lines() {
612                if let Some(start) = line.find('<')
613                    && let Some(end) = line[start + 1..].find('>')
614                    && let Ok(id) = line[start + 1..start + 1 + end].parse::<usize>()
615                {
616                    all_ids.push(id);
617                }
618            }
619        }
620        all_ids.sort();
621        assert_eq!(all_ids, (1..=300).collect::<Vec<_>>());
622    }
623
624    #[tokio::test]
625    async fn pipeline_propagates_translator_error() {
626        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
627        let translator = FakeTranslator::with_error("rate limit exceeded");
628
629        let err = translate_ass(ass, "pt-BR", 1, Arc::new(translator) as Arc<dyn Translator>).await.unwrap_err();
630
631        assert!(err.to_string().contains("fake"));
632        assert!(err.to_string().contains("rate limit exceeded"));
633    }
634
635    #[tokio::test]
636    async fn pipeline_rejects_incomplete_translation() {
637        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
638
639        // Translator returns only one of two expected IDs
640        let mut responses = std::collections::HashMap::new();
641        responses.insert(
642            "<1> Hello world\n<2> Goodbye world".to_string(),
643            "<1> Olá mundo".to_string(), // missing <2>
644        );
645
646        let translator = FakeTranslator::new(responses);
647        let err = translate_ass(ass, "pt-BR", 1, Arc::new(translator) as Arc<dyn Translator>).await.unwrap_err();
648
649        assert!(err.to_string().contains("missing id <2>"));
650    }
651
652    #[tokio::test]
653    async fn pipeline_handles_multiline_cues() {
654        let ass_content = r"[Script Info]
655Title: Multiline
656ScriptType: v4.00+
657
658[V4+ Styles]
659Format: Name, Fontname, Fontsize
660Style: Default,Arial,20
661
662[Events]
663Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
664Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,First line\NSecond line
665";
666
667        let ass = AssSubtitle::parse(ass_content).unwrap();
668
669        let mut responses = std::collections::HashMap::new();
670        responses.insert(
671            "<1> First line\\NSecond line".to_string(),
672            "<1> Primeira linha\\NSegunda linha".to_string(),
673        );
674
675        let translator = Arc::new(FakeTranslator::new(responses));
676        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
677        let rendered = result.render();
678
679        assert!(rendered.contains("Primeira linha"));
680        assert!(rendered.contains("Segunda linha"));
681    }
682
683    #[tokio::test]
684    async fn pipeline_translates_basic_document() {
685        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
686
687        let mut responses = std::collections::HashMap::new();
688        responses.insert(
689            "<1> Hello world\n<2> Goodbye world".to_string(),
690            "<1> translated1\n<2> translated2".to_string(),
691        );
692
693        let translator = Arc::new(FakeTranslator::new(responses));
694        translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>)
695            .await
696            .unwrap();
697
698        let texts = translator.received_texts();
699        assert_eq!(texts.len(), 1);
700    }
701
702    #[test]
703    fn dry_run_summary_reports_cues_chars_chunks() {
704        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
705        let summary = dry_run_summary(ass.document(), "pt-BR");
706
707        assert!(summary.contains("2 cues"), "expected '2 cues' in: {summary}");
708        assert!(summary.contains("1 chunk(s)"), "expected '1 chunk(s)' in: {summary}");
709        assert!(summary.contains("→ pt-BR"), "expected '→ pt-BR' in: {summary}");
710    }
711
712    #[test]
713    fn dry_run_summary_counts_chars() {
714        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
715        let summary = dry_run_summary(ass.document(), "en");
716
717        // "Hello world" = 11 chars, "Goodbye world" = 13 chars = 24 total
718        assert!(summary.contains("24 chars"), "expected '24 chars' in: {summary}");
719    }
720
721    #[test]
722    fn dry_run_summary_handles_empty_document() {
723        let ass_content = r"[Script Info]
724Title: Empty
725ScriptType: v4.00+
726
727[V4+ Styles]
728Format: Name, Fontname, Fontsize
729Style: Default,Arial,20
730
731[Events]
732Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
733";
734        let ass = AssSubtitle::parse(ass_content).unwrap();
735        let summary = dry_run_summary(ass.document(), "de");
736
737        assert!(summary.contains("0 cues"), "expected '0 cues' in: {summary}");
738        assert!(summary.contains("0 chars"), "expected '0 chars' in: {summary}");
739        assert!(summary.contains("0 chunk(s)"), "expected '0 chunk(s)' in: {summary}");
740    }
741
742    #[test]
743    fn dry_run_summary_splits_large_documents() {
744        // Build ASS with 100 cues to force multiple chunks
745        let mut lines = vec![
746            "[Script Info]".to_string(),
747            "Title: Big".to_string(),
748            "ScriptType: v4.00+".to_string(),
749            "".to_string(),
750            "[V4+ Styles]".to_string(),
751            "Format: Name, Fontname, Fontsize".to_string(),
752            "Style: Default,Arial,20".to_string(),
753            "".to_string(),
754            "[Events]".to_string(),
755            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
756                .to_string(),
757        ];
758        for i in 1..=300 {
759            lines.push(format!(
760                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,This is subtitle line number {i} with enough text to fill space",
761                i, i + 1,
762            ));
763        }
764        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
765        let summary = dry_run_summary(ass.document(), "pt-BR");
766
767        assert!(summary.contains("300 cues"), "expected '300 cues' in: {summary}");
768        assert!(
769            summary.contains("2 chunk(s)"),
770            "expected '2 chunk(s)' in: {summary}"
771        );
772    }
773
774    #[tokio::test]
775    async fn pipeline_retries_chunk_on_malformed_output() {
776        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
777
778        // First call: returns broken output (missing <2>)
779        // Second call: returns correct output
780        let translator = Arc::new(FakeTranslator::with_sequential_responses(vec![
781            "<1> Olá mundo".to_string(),           // missing <2>
782            "<1> Olá mundo\n<2> Adeus mundo".to_string(), // correct
783        ]));
784
785        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
786        let rendered = result.render();
787
788        assert!(rendered.contains("Olá mundo"));
789        assert!(rendered.contains("Adeus mundo"));
790
791        // Should have been called twice (1 failed + 1 success)
792        let texts = translator.received_texts();
793        assert_eq!(texts.len(), 2);
794    }
795
796    #[tokio::test]
797    async fn pipeline_gives_up_after_repeated_malformed_output() {
798        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
799
800        // Always returns broken output (missing <2>)
801        let translator = Arc::new(FakeTranslator::with_sequential_responses(vec![
802            "<1> Olá mundo".to_string(),
803            "<1> Olá mundo".to_string(),
804            "<1> Olá mundo".to_string(),
805            "<1> Olá mundo".to_string(), // 4 attempts total (1 initial + 3 retries)
806        ]));
807
808        let err = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>)
809            .await
810            .unwrap_err();
811
812        assert!(err.to_string().contains("missing id <2>"));
813
814        // Should have been called 4 times (1 initial + 3 retries)
815        let texts = translator.received_texts();
816        assert_eq!(texts.len(), 4);
817    }
818
819    #[test]
820    fn progress_file_path_for_directory() {
821        let path = progress_file_path(std::path::Path::new("/media/anime"));
822        assert_eq!(
823            path,
824            std::path::PathBuf::from("/media/anime/.psyche-subtitle-toolkit-progress.json")
825        );
826    }
827
828    #[test]
829    fn progress_file_path_for_file() {
830        let path = progress_file_path(std::path::Path::new("/media/anime/episode.mkv"));
831        assert_eq!(
832            path,
833            std::path::PathBuf::from("/media/anime/.psyche-subtitle-toolkit-progress.json")
834        );
835    }
836
837    #[tokio::test]
838    async fn progress_file_roundtrip() {
839        let dir = tempfile::tempdir().unwrap();
840        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");
841
842        let completed = vec![
843            "/media/anime/ep1.mkv".to_string(),
844            "/media/anime/ep2.mkv".to_string(),
845        ];
846        let json = serde_json::to_string_pretty(&completed).unwrap();
847        tokio::fs::write(&progress_path, &json).await.unwrap();
848
849        let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
850        let loaded: Vec<String> = serde_json::from_str(&data).unwrap();
851
852        assert_eq!(loaded, completed);
853    }
854
855    #[tokio::test]
856    async fn progress_file_handles_missing_file() {
857        let dir = tempfile::tempdir().unwrap();
858        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");
859
860        // Should return empty vec when file doesn't exist
861        let completed: Vec<String> = if progress_path.exists() {
862            let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
863            serde_json::from_str(&data).unwrap_or_default()
864        } else {
865            Vec::new()
866        };
867
868        assert!(completed.is_empty());
869    }
870
871    #[tokio::test]
872    async fn progress_file_handles_corrupted_json() {
873        let dir = tempfile::tempdir().unwrap();
874        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");
875
876        tokio::fs::write(&progress_path, "not valid json")
877            .await
878            .unwrap();
879
880        let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
881        let completed: Vec<String> = serde_json::from_str(&data).unwrap_or_default();
882
883        // Should fall back to empty vec on corrupt JSON
884        assert!(completed.is_empty());
885    }
886
887    #[tokio::test]
888    async fn pipeline_translates_concurrently() {
889        // Build ASS with 300 cues → 2 chunks at 200 lines/chunk
890        let mut lines = vec![
891            "[Script Info]".to_string(),
892            "Title: Concurrent".to_string(),
893            "ScriptType: v4.00+".to_string(),
894            "".to_string(),
895            "[V4+ Styles]".to_string(),
896            "Format: Name, Fontname, Fontsize".to_string(),
897            "Style: Default,Arial,20".to_string(),
898            "".to_string(),
899            "[Events]".to_string(),
900            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
901                .to_string(),
902        ];
903        for i in 1..=300 {
904            lines.push(format!(
905                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,line {i}",
906                i, i + 1,
907            ));
908        }
909        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
910        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
911
912        // Translate with max_concurrent=2 (both chunks in parallel)
913        let result = translate_ass(
914            ass,
915            "pt-BR",
916            2,
917            translator.clone() as Arc<dyn Translator>,
918        )
919        .await
920        .unwrap();
921
922        let rendered = result.render();
923        for i in 1..=300 {
924            assert!(rendered.contains(&format!("line {i}")), "missing cue {i}");
925        }
926
927        // Both chunks should have been translated
928        let texts = translator.received_texts();
929        assert_eq!(texts.len(), 2);
930    }
931
932    #[tokio::test]
933    async fn concurrent_translation_preserves_all_cues() {
934        // Stress test: 1000 cues, 5 chunks, max_concurrent=5
935        let mut lines = vec![
936            "[Script Info]".to_string(),
937            "Title: Stress".to_string(),
938            "ScriptType: v4.00+".to_string(),
939            "".to_string(),
940            "[V4+ Styles]".to_string(),
941            "Format: Name, Fontname, Fontsize".to_string(),
942            "Style: Default,Arial,20".to_string(),
943            "".to_string(),
944            "[Events]".to_string(),
945            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
946                .to_string(),
947        ];
948        for i in 1..=1000 {
949            lines.push(format!(
950                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,stress line {i}",
951                i, i + 1,
952            ));
953        }
954        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
955        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
956
957        let result = translate_ass(
958            ass,
959            "pt-BR",
960            5,
961            translator.clone() as Arc<dyn Translator>,
962        )
963        .await
964        .unwrap();
965
966        let rendered = result.render();
967        // Every single cue must be present
968        for i in 1..=1000 {
969            assert!(
970                rendered.contains(&format!("stress line {i}")),
971                "missing cue {i} under concurrent translation"
972            );
973        }
974
975        // 1000 cues / 200 lines per chunk = 5 chunks
976        let texts = translator.received_texts();
977        assert_eq!(texts.len(), 5, "expected 5 chunk calls, got {}", texts.len());
978    }
979
980    #[tokio::test]
981    async fn concurrent_translation_output_is_deterministic() {
982        // Run the same document through concurrent translation twice.
983        // Output must be identical regardless of task scheduling.
984        let make_ass = || {
985            let mut lines = vec![
986                "[Script Info]".to_string(),
987                "Title: Deterministic".to_string(),
988                "ScriptType: v4.00+".to_string(),
989                "".to_string(),
990                "[V4+ Styles]".to_string(),
991                "Format: Name, Fontname, Fontsize".to_string(),
992                "Style: Default,Arial,20".to_string(),
993                "".to_string(),
994                "[Events]".to_string(),
995                "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
996                    .to_string(),
997            ];
998            for i in 1..=500 {
999                lines.push(format!(
1000                    "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,det line {i}",
1001                    i, i + 1,
1002                ));
1003            }
1004            AssSubtitle::parse(&lines.join("\n")).unwrap()
1005        };
1006
1007        let t1 = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
1008        let r1 = translate_ass(
1009            make_ass(),
1010            "pt-BR",
1011            3,
1012            t1.clone() as Arc<dyn Translator>,
1013        )
1014        .await
1015        .unwrap();
1016
1017        let t2 = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
1018        let r2 = translate_ass(
1019            make_ass(),
1020            "pt-BR",
1021            3,
1022            t2.clone() as Arc<dyn Translator>,
1023        )
1024        .await
1025        .unwrap();
1026
1027        assert_eq!(r1.render(), r2.render(), "concurrent output is non-deterministic");
1028    }
1029
1030    #[tokio::test]
1031    async fn concurrent_error_propagates_correctly() {
1032        // First chunk succeeds, second chunk always fails.
1033        // The pipeline should return an error, not silently succeed.
1034        let mut lines = vec![
1035            "[Script Info]".to_string(),
1036            "Title: ErrProp".to_string(),
1037            "ScriptType: v4.00+".to_string(),
1038            "".to_string(),
1039            "[V4+ Styles]".to_string(),
1040            "Format: Name, Fontname, Fontsize".to_string(),
1041            "Style: Default,Arial,20".to_string(),
1042            "".to_string(),
1043            "[Events]".to_string(),
1044            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
1045                .to_string(),
1046        ];
1047        for i in 1..=400 {
1048            lines.push(format!(
1049                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,line {i}",
1050                i, i + 1,
1051            ));
1052        }
1053        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
1054
1055        // Always error — all chunks will fail
1056        let translator = Arc::new(FakeTranslator::with_error("provider down"));
1057
1058        let err = translate_ass(
1059            ass,
1060            "pt-BR",
1061            3,
1062            translator.clone() as Arc<dyn Translator>,
1063        )
1064        .await
1065        .unwrap_err();
1066
1067        assert!(err.to_string().contains("provider down"));
1068    }
1069
1070    /// A translator that tracks the maximum number of concurrent translate calls.
1071    /// Used to verify the semaphore actually bounds concurrency.
1072    struct ConcurrencyTrackingTranslator {
1073        active: std::sync::atomic::AtomicU32,
1074        max_observed: std::sync::atomic::AtomicU32,
1075        received: Mutex<Vec<String>>,
1076    }
1077
1078    impl ConcurrencyTrackingTranslator {
1079        fn new() -> Self {
1080            Self {
1081                active: std::sync::atomic::AtomicU32::new(0),
1082                max_observed: std::sync::atomic::AtomicU32::new(0),
1083                received: Mutex::new(Vec::new()),
1084            }
1085        }
1086
1087        fn max_concurrent_calls(&self) -> u32 {
1088            self.max_observed.load(std::sync::atomic::Ordering::SeqCst)
1089        }
1090    }
1091
1092    #[async_trait::async_trait]
1093    impl Translator for ConcurrencyTrackingTranslator {
1094        async fn translate(&self, request: TranslationRequest<'_>) -> crate::error::Result<String> {
1095            self.received.lock().unwrap().push(request.source_text.to_string());
1096
1097            let current = self.active.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
1098            // Update max_observed
1099            self.max_observed.fetch_max(current, std::sync::atomic::Ordering::SeqCst);
1100
1101            // Simulate work — yield to let other tasks run
1102            tokio::task::yield_now().await;
1103
1104            self.active.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
1105
1106            // Identity translation
1107            Ok(request.source_text.to_string())
1108        }
1109    }
1110
1111    #[tokio::test]
1112    async fn semaphore_bounds_concurrency() {
1113        // 600 cues → 3 chunks at 200 lines/chunk, max_concurrent=2
1114        // The semaphore should prevent all 3 from running simultaneously.
1115        let mut lines = vec![
1116            "[Script Info]".to_string(),
1117            "Title: Semaphore".to_string(),
1118            "ScriptType: v4.00+".to_string(),
1119            "".to_string(),
1120            "[V4+ Styles]".to_string(),
1121            "Format: Name, Fontname, Fontsize".to_string(),
1122            "Style: Default,Arial,20".to_string(),
1123            "".to_string(),
1124            "[Events]".to_string(),
1125            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
1126                .to_string(),
1127        ];
1128        for i in 1..=600 {
1129            lines.push(format!(
1130                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,sem line {i}",
1131                i, i + 1,
1132            ));
1133        }
1134        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
1135        let translator = Arc::new(ConcurrencyTrackingTranslator::new());
1136
1137        let result = translate_ass(
1138            ass,
1139            "pt-BR",
1140            2, // max_concurrent=2
1141            translator.clone() as Arc<dyn Translator>,
1142        )
1143        .await
1144        .unwrap();
1145
1146        // All 600 cues should be present
1147        let rendered = result.render();
1148        for i in 1..=600 {
1149            assert!(rendered.contains(&format!("sem line {i}")), "missing cue {i}");
1150        }
1151
1152        // The max observed concurrency should be <= 2
1153        let max = translator.max_concurrent_calls();
1154        assert!(
1155            max <= 2,
1156            "semaphore failed to bound concurrency: observed {max} concurrent calls (expected <= 2)"
1157        );
1158
1159        // All 3 chunks should have been called
1160        let texts = translator.received.lock().unwrap();
1161        assert_eq!(texts.len(), 3);
1162    }
1163
1164    #[tokio::test]
1165    async fn sequential_mode_is_deterministic() {
1166        // With max_concurrent=1, received_texts() must be in spawn order.
1167        let mut lines = vec![
1168            "[Script Info]".to_string(),
1169            "Title: Seq".to_string(),
1170            "ScriptType: v4.00+".to_string(),
1171            "".to_string(),
1172            "[V4+ Styles]".to_string(),
1173            "Format: Name, Fontname, Fontsize".to_string(),
1174            "Style: Default,Arial,20".to_string(),
1175            "".to_string(),
1176            "[Events]".to_string(),
1177            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
1178                .to_string(),
1179        ];
1180        for i in 1..=500 {
1181            lines.push(format!(
1182                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,seq line {i}",
1183                i, i + 1,
1184            ));
1185        }
1186        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
1187        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
1188
1189        translate_ass(
1190            ass,
1191            "pt-BR",
1192            1, // sequential
1193            translator.clone() as Arc<dyn Translator>,
1194        )
1195        .await
1196        .unwrap();
1197
1198        let texts = translator.received_texts();
1199        assert_eq!(texts.len(), 3);
1200
1201        // In sequential mode, chunk 1 should be called before chunk 2, etc.
1202        // Verify by checking that each text starts with the expected cue ID range.
1203        assert!(texts[0].starts_with("<1> "), "first chunk should start with <1>");
1204        assert!(texts[1].starts_with("<201> "), "second chunk should start with <201>");
1205        assert!(texts[2].starts_with("<401> "), "third chunk should start with <401>");
1206    }
1207}