1use std::collections::BTreeMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use tempfile::tempdir;
6use tokio::sync::Semaphore;
7
8use crate::error::{Result, SubtitleToolkitError};
9use crate::media::mkv::{
10 SubtitleFormat, discover_mkv_files, extract_subtitle, inspect_mkv, mux_subtitle_in_place,
11 select_subtitle_track,
12};
13use crate::retry::retry_async;
14use crate::subtitles::ass::AssSubtitle;
15use crate::subtitles::model::SubtitleDocument;
16use crate::subtitles::srt::SrtSubtitle;
17use crate::subtitles::vtt::VttSubtitle;
18use crate::subtitles::structured::{
19 apply_translation, chunk_document_by_lines, parse_numbered_text, reinject_tags, strip_tags,
20 to_numbered_text,
21};
22use crate::translation::{TranslationRequest, Translator};
23
/// Settings for [`translate_mkv`].
#[derive(Debug, Clone)]
pub struct TranslateMkvOptions {
    /// MKV file or directory to process. Passed to `discover_mkv_files` and also
    /// used to decide where the resume progress file lives.
    pub input: PathBuf,
    /// Language the subtitles are translated into. Sent to the translator and
    /// passed to `mux_subtitle_in_place` for the new track.
    pub target_language: String,
    /// Subtitle track to translate, forwarded to `select_subtitle_track`
    /// (presumably `None` lets it choose a track itself — confirm there).
    pub track_id: Option<u64>,
    /// Copy the extracted and translated subtitle files of each processed MKV
    /// into a `<file>.psyche-subtitle-toolkit-temp` directory next to it.
    pub keep_temp: bool,
    /// Only print a per-file `dry_run_summary`; nothing is translated or muxed.
    pub dry_run: bool,
    /// Record finished files in a progress file and skip them on the next run.
    pub resume: bool,
    /// Upper bound on chunk translation requests in flight for a single file.
    pub max_concurrent: usize,
}
42
43pub async fn translate_mkv(
56 options: TranslateMkvOptions,
57 translator: Arc<dyn Translator>,
58) -> Result<()> {
59 let files = discover_mkv_files(&options.input).await?;
60
61 let progress_path = progress_file_path(&options.input);
62 let mut completed: Vec<String> = if options.resume && progress_path.exists() {
63 let data = tokio::fs::read_to_string(&progress_path).await?;
64 serde_json::from_str(&data).unwrap_or_default()
65 } else {
66 Vec::new()
67 };
68
69 let total = files.len();
70 for (i, file) in files.into_iter().enumerate() {
71 let file_str = file.to_string_lossy().to_string();
72
73 if options.resume && completed.contains(&file_str) {
74 eprintln!("[resume] skipping ({}/{}): {}", i + 1, total, file.display());
75 continue;
76 }
77
78 translate_one(file, &options, translator.clone()).await?;
79
80 if options.resume {
81 completed.push(file_str);
82 let json = serde_json::to_string_pretty(&completed)?;
83 tokio::fs::write(&progress_path, &json).await?;
84 eprintln!("[resume] progress saved ({}/{})", completed.len(), total);
85 }
86 }
87
88 if options.resume && progress_path.exists() {
89 tokio::fs::remove_file(&progress_path).await?;
90 }
91
92 Ok(())
93}
94
95async fn translate_one(
96 file: PathBuf,
97 options: &TranslateMkvOptions,
98 translator: Arc<dyn Translator>,
99) -> Result<()> {
100 let info = inspect_mkv(&file).await?;
101 let (track, format) = select_subtitle_track(&info, options.track_id)
102 .ok_or_else(|| SubtitleToolkitError::NoSubtitleTrack { path: file.clone() })?;
103
104 let ext = match format {
105 SubtitleFormat::Ass => "ass",
106 SubtitleFormat::Srt => "srt",
107 SubtitleFormat::Vtt => "vtt",
108 };
109
110 let temp_dir = tempdir()?;
111 let extracted_path = temp_dir.path().join(format!("source.{ext}"));
112 let translated_path = temp_dir.path().join(format!("translated.{ext}"));
113
114 let format_label = match format {
115 SubtitleFormat::Ass => "ASS",
116 SubtitleFormat::Srt => "SRT",
117 SubtitleFormat::Vtt => "VTT",
118 };
119 eprintln!("[translate] {} ({})", file.display(), format_label);
120 extract_subtitle(&file, track.id, &extracted_path).await?;
121
122 let source = tokio::fs::read_to_string(&extracted_path).await?;
123
124 let rendered = match format {
125 SubtitleFormat::Ass => {
126 let ass = AssSubtitle::parse(&source)?;
127 if options.dry_run {
128 let summary = dry_run_summary(ass.document(), &options.target_language);
129 println!("[dry-run] {}: {}", file.display(), summary);
130 return Ok(());
131 }
132 translate_ass(
133 ass,
134 &options.target_language,
135 options.max_concurrent,
136 translator,
137 )
138 .await?
139 .render()
140 }
141 SubtitleFormat::Srt => {
142 let srt = SrtSubtitle::parse(&source)?;
143 if options.dry_run {
144 let summary = dry_run_summary(srt.document(), &options.target_language);
145 println!("[dry-run] {}: {}", file.display(), summary);
146 return Ok(());
147 }
148 translate_srt(
149 srt,
150 &options.target_language,
151 options.max_concurrent,
152 translator,
153 )
154 .await?
155 .render()
156 }
157 SubtitleFormat::Vtt => {
158 let vtt = VttSubtitle::parse(&source)?;
159 if options.dry_run {
160 let summary = dry_run_summary(vtt.document(), &options.target_language);
161 println!("[dry-run] {}: {}", file.display(), summary);
162 return Ok(());
163 }
164 translate_vtt(
165 vtt,
166 &options.target_language,
167 options.max_concurrent,
168 translator,
169 )
170 .await?
171 .render()
172 }
173 };
174
175 eprintln!("[translate] muxing translated subtitle");
176 tokio::fs::write(&translated_path, &rendered).await?;
177 mux_subtitle_in_place(&file, track.id, &translated_path, &options.target_language).await?;
178 eprintln!("[translate] done: {}", file.display());
179
180 if options.keep_temp {
181 let persisted = file.with_extension("psyche-subtitle-toolkit-temp");
182 tokio::fs::create_dir_all(&persisted).await?;
183 tokio::fs::copy(&extracted_path, persisted.join(format!("source.{ext}"))).await?;
184 tokio::fs::copy(&translated_path, persisted.join(format!("translated.{ext}"))).await?;
185 }
186
187 Ok(())
188}
189
/// Location of the resume progress file for a given `input`.
///
/// An existing directory, or any path without an extension, is used as-is.
/// A path with an extension that is not an existing directory is treated as a
/// single file, and its parent directory is used. The progress file is
/// `.psyche-subtitle-toolkit-progress.json` inside the chosen directory.
fn progress_file_path(input: &std::path::Path) -> PathBuf {
    let looks_like_file = !input.is_dir() && input.extension().is_some();
    let base = if looks_like_file {
        input.parent().unwrap_or(input)
    } else {
        input
    };
    base.join(".psyche-subtitle-toolkit-progress.json")
}
200
201pub fn dry_run_summary(doc: &SubtitleDocument, target_language: &str) -> String {
203 let (clean_doc, _) = strip_tags(doc);
204 let chunks = chunk_document_by_lines(&clean_doc, 200);
205 let cue_count = clean_doc.cues.len();
206 let total_chars: usize = clean_doc.cues.iter().map(|c| c.text.len()).sum();
207 format!(
208 "{} cues, {} chars, {} chunk(s) → {}",
209 cue_count,
210 total_chars,
211 chunks.len(),
212 target_language,
213 )
214}
215
216pub async fn translate_ass(
227 mut ass: AssSubtitle,
228 target_language: &str,
229 max_concurrent: usize,
230 translator: Arc<dyn Translator>,
231) -> Result<AssSubtitle> {
232 let (mut clean_doc, tag_map) = strip_tags(ass.document());
233 clean_doc = translate_document(clean_doc, target_language, max_concurrent, translator).await?;
234 reinject_tags(&mut clean_doc, &tag_map);
235 *ass.document_mut() = clean_doc;
236 Ok(ass)
237}
238
239pub async fn translate_srt(
248 mut srt: SrtSubtitle,
249 target_language: &str,
250 max_concurrent: usize,
251 translator: Arc<dyn Translator>,
252) -> Result<SrtSubtitle> {
253 let clean_doc = translate_document(
254 srt.document().clone(),
255 target_language,
256 max_concurrent,
257 translator,
258 )
259 .await?;
260 *srt.document_mut() = clean_doc;
261 Ok(srt)
262}
263
264pub async fn translate_vtt(
273 mut vtt: VttSubtitle,
274 target_language: &str,
275 max_concurrent: usize,
276 translator: Arc<dyn Translator>,
277) -> Result<VttSubtitle> {
278 let clean_doc = translate_document(
279 vtt.document().clone(),
280 target_language,
281 max_concurrent,
282 translator,
283 )
284 .await?;
285 *vtt.document_mut() = clean_doc;
286 Ok(vtt)
287}
288
289async fn translate_document(
291 mut doc: SubtitleDocument,
292 target_language: &str,
293 max_concurrent: usize,
294 translator: Arc<dyn Translator>,
295) -> Result<SubtitleDocument> {
296 let chunks = chunk_document_by_lines(&doc, 200);
297 let chunk_count = chunks.len();
298 let cue_count = doc.cues.len();
299 let total_chars: usize = doc.cues.iter().map(|c| c.text.len()).sum();
300 eprintln!(
301 "[translate] {} cues, {} chars, {} chunk(s), {} concurrent",
302 cue_count, total_chars, chunk_count, max_concurrent,
303 );
304
305 let semaphore = Arc::new(Semaphore::new(max_concurrent));
306 let mut join_set = tokio::task::JoinSet::new();
307
308 for (i, chunk) in chunks.into_iter().enumerate() {
309 if chunk_count > 1 {
310 let chunk_chars: usize = chunk.cues.iter().map(|c| c.text.len()).sum();
311 eprintln!(
312 "[translate] chunk {}/{}: {} cues, {} chars",
313 i + 1,
314 chunk_count,
315 chunk.cues.len(),
316 chunk_chars,
317 );
318 }
319 let numbered = to_numbered_text(&chunk);
320 let ids: Vec<usize> = chunk.cues.iter().map(|cue| cue.id).collect();
321 let permit = semaphore
322 .clone()
323 .acquire_owned()
324 .await
325 .map_err(|e| SubtitleToolkitError::Translation {
326 provider: "pipeline",
327 message: format!("semaphore closed: {e}"),
328 })?;
329 let translator = translator.clone();
330 let target = target_language.to_string();
331
332 join_set.spawn(async move {
333 let _permit = permit;
334 let numbered_clone = numbered.clone();
335 let ids_clone = ids.clone();
336 let result = retry_async(3, || {
337 let numbered = numbered_clone.clone();
338 let ids = ids_clone.clone();
339 let translator = translator.clone();
340 let target = target.clone();
341 async move {
342 let translated_text = translator
343 .translate(TranslationRequest {
344 source_text: &numbered,
345 target_language: &target,
346 })
347 .await?;
348 parse_numbered_text(&translated_text, &ids)
349 }
350 })
351 .await;
352 (i, result)
353 });
354 }
355
356 let mut all_translated = BTreeMap::new();
357 let mut results: Vec<(usize, Result<BTreeMap<usize, String>>)> = Vec::new();
358 while let Some(result) = join_set.join_next().await {
359 let (i, outcome) = result.map_err(|e| SubtitleToolkitError::Translation {
360 provider: "pipeline",
361 message: format!("task panicked: {e}"),
362 })?;
363 results.push((i, outcome));
364 }
365 results.sort_by_key(|(i, _)| *i);
366 for (_, result) in results {
367 all_translated.extend(result?);
368 }
369
370 apply_translation(&mut doc, all_translated);
371 Ok(doc)
372}
373
// Unit tests for the translation pipeline. Translators are replaced by in-memory
// fakes, so these tests never call a real translation provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubtitleToolkitError;
    use crate::subtitles::ass::AssSubtitle;
    use crate::translation::{TranslationRequest, Translator};
    use std::sync::Mutex;

    /// Scriptable `Translator` used by most pipeline tests.
    struct FakeTranslator {
        /// Every `source_text` received, in call order.
        received: Mutex<Vec<String>>,
        /// Canned replies keyed by exact source text; unmatched input is echoed back.
        responses: std::collections::HashMap<String, String>,
        /// When set, every call fails with this message.
        error: Option<String>,
        /// Replies consumed front-to-back, one per call, before `responses` is consulted.
        sequential: Mutex<Vec<String>>,
    }

    impl FakeTranslator {
        /// Answers from `responses`, echoing unknown input.
        fn new(responses: std::collections::HashMap<String, String>) -> Self {
            Self {
                received: Mutex::new(Vec::new()),
                responses,
                error: None,
                sequential: Mutex::new(Vec::new()),
            }
        }

        /// Fails every call with a `Translation` error from provider `"fake"`.
        fn with_error(message: &str) -> Self {
            Self {
                received: Mutex::new(Vec::new()),
                responses: std::collections::HashMap::new(),
                error: Some(message.to_string()),
                sequential: Mutex::new(Vec::new()),
            }
        }

        /// Returns `responses` in order, then falls back to echoing the input.
        fn with_sequential_responses(responses: Vec<String>) -> Self {
            Self {
                received: Mutex::new(Vec::new()),
                responses: std::collections::HashMap::new(),
                error: None,
                sequential: Mutex::new(responses),
            }
        }

        /// Snapshot of every source text sent to this translator so far.
        fn received_texts(&self) -> Vec<String> {
            self.received.lock().unwrap().clone()
        }
    }

    #[async_trait::async_trait]
    impl Translator for FakeTranslator {
        async fn translate(&self, request: TranslationRequest<'_>) -> crate::error::Result<String> {
            // Recorded before any failure so tests can count attempts, including retries.
            self.received
                .lock()
                .unwrap()
                .push(request.source_text.to_string());

            if let Some(msg) = &self.error {
                return Err(SubtitleToolkitError::Translation {
                    provider: "fake",
                    message: msg.clone(),
                });
            }

            {
                // Scoped so the lock is released before falling through to `responses`.
                let mut seq = self.sequential.lock().unwrap();
                if !seq.is_empty() {
                    return Ok(seq.remove(0));
                }
            }

            Ok(self
                .responses
                .get(request.source_text)
                .cloned()
                .unwrap_or_else(|| request.source_text.to_string()))
        }
    }

    // Minimal two-cue script without override tags.
    const SIMPLE_ASS: &str = r"[Script Info]
Title: Test
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,Hello world
Dialogue: 0,0:00:03.00,0:00:04.00,Default,,0,0,0,,Goodbye world
";

    // Two-cue script whose first cue starts with `{\pos}` / `{\an7}` override tags.
    const ASS_WITH_TAGS: &str = r"[Script Info]
Title: Test Tags
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,{\pos(857.6,122.4)}{\an7}Status line
Dialogue: 0,0:00:03.00,0:00:04.00,Default,,0,0,0,,Normal text
";

    #[tokio::test]
    async fn pipeline_translates_dialogue_and_preserves_structure() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        let mut responses = std::collections::HashMap::new();
        responses.insert("<1> Hello world\n<2> Goodbye world".to_string(), "<1> Olá mundo\n<2> Adeus mundo".to_string());

        let translator = Arc::new(FakeTranslator::new(responses));
        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
        let rendered = result.render();

        assert!(rendered.contains("Olá mundo"));
        assert!(rendered.contains("Adeus mundo"));
        assert!(!rendered.contains("Hello world"));
        assert!(!rendered.contains("Goodbye world"));

        // Non-dialogue sections must survive the round trip.
        assert!(rendered.contains("[Script Info]"));
        assert!(rendered.contains("[V4+ Styles]"));
        assert!(rendered.contains("[Events]"));
    }

    #[tokio::test]
    async fn pipeline_passes_numbered_text_and_target_language_to_translator() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        let mut responses = std::collections::HashMap::new();
        responses.insert("<1> Hello world\n<2> Goodbye world".to_string(), "<1> translated1\n<2> translated2".to_string());

        let translator = Arc::new(FakeTranslator::new(responses));
        translate_ass(ass, "ja", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();

        // One chunk, sent as `<id> text` lines.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 1);
        assert_eq!(texts[0], "<1> Hello world\n<2> Goodbye world");
    }

    #[tokio::test]
    async fn pipeline_strips_and_reinjects_override_tags() {
        let ass = AssSubtitle::parse(ASS_WITH_TAGS).unwrap();

        let mut responses = std::collections::HashMap::new();
        // Key is the tag-free text the translator is expected to receive.
        responses.insert(
            "<1> Status line\n<2> Normal text".to_string(),
            "<1> Linha de status\n<2> Texto normal".to_string(),
        );

        let translator = Arc::new(FakeTranslator::new(responses));
        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
        let rendered = result.render();

        // Tags come back in front of the translated text.
        assert!(rendered.contains(r"{\pos(857.6,122.4)}{\an7}Linha de status"));
        assert!(rendered.contains("Texto normal"));
        assert!(!rendered.contains("Normal text"));

        // ...and were never shown to the translator.
        let texts = translator.received_texts();
        assert!(!texts[0].contains(r"{\pos"));
        assert!(!texts[0].contains(r"{\an7}"));
    }

    #[tokio::test]
    async fn pipeline_chunks_large_documents() {
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Big".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];

        for i in 1..=300 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,subtitle line {i}",
                i, i + 1,
            ));
        }

        let ass_content = lines.join("\n");

        // Check the chunking helper directly first: 300 cues -> 200 + 100.
        let ass = AssSubtitle::parse(&ass_content).unwrap();
        let (clean_doc, _) = crate::subtitles::structured::strip_tags(ass.document());
        let chunks = crate::subtitles::structured::chunk_document_by_lines(&clean_doc, 200);
        assert_eq!(chunks.len(), 2, "300 cues should produce 2 chunks at 200 lines");
        assert_eq!(chunks[0].cues.len(), 200);
        assert_eq!(chunks[1].cues.len(), 100);

        // Empty response map: the fake echoes input, so the output equals the input.
        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));

        let ass = AssSubtitle::parse(&ass_content).unwrap();
        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
        let rendered = result.render();

        for i in 1..=300 {
            assert!(
                rendered.contains(&format!("subtitle line {i}")),
                "missing cue {i} in rendered output"
            );
        }

        // One translator call per chunk.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 2);

        // Every cue id appears exactly once across all chunks.
        let mut all_ids: Vec<usize> = Vec::new();
        for text in &texts {
            for line in text.lines() {
                if let Some(start) = line.find('<')
                    && let Some(end) = line[start + 1..].find('>')
                    && let Ok(id) = line[start + 1..start + 1 + end].parse::<usize>()
                {
                    all_ids.push(id);
                }
            }
        }
        all_ids.sort();
        assert_eq!(all_ids, (1..=300).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn pipeline_propagates_translator_error() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
        let translator = FakeTranslator::with_error("rate limit exceeded");

        let err = translate_ass(ass, "pt-BR", 1, Arc::new(translator) as Arc<dyn Translator>).await.unwrap_err();

        // Both the provider name and the original message reach the caller.
        assert!(err.to_string().contains("fake"));
        assert!(err.to_string().contains("rate limit exceeded"));
    }

    #[tokio::test]
    async fn pipeline_rejects_incomplete_translation() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        let mut responses = std::collections::HashMap::new();
        responses.insert(
            "<1> Hello world\n<2> Goodbye world".to_string(),
            "<1> Olá mundo".to_string(), // id <2> deliberately missing from the reply
        );

        let translator = FakeTranslator::new(responses);
        let err = translate_ass(ass, "pt-BR", 1, Arc::new(translator) as Arc<dyn Translator>).await.unwrap_err();

        assert!(err.to_string().contains("missing id <2>"));
    }

    #[tokio::test]
    async fn pipeline_handles_multiline_cues() {
        // `\N` is the ASS hard line break; it must be kept inside a single numbered line.
        let ass_content = r"[Script Info]
Title: Multiline
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
Dialogue: 0,0:00:01.00,0:00:02.00,Default,,0,0,0,,First line\NSecond line
";

        let ass = AssSubtitle::parse(ass_content).unwrap();

        let mut responses = std::collections::HashMap::new();
        responses.insert(
            "<1> First line\\NSecond line".to_string(),
            "<1> Primeira linha\\NSegunda linha".to_string(),
        );

        let translator = Arc::new(FakeTranslator::new(responses));
        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
        let rendered = result.render();

        assert!(rendered.contains("Primeira linha"));
        assert!(rendered.contains("Segunda linha"));
    }

    #[tokio::test]
    async fn pipeline_translates_basic_document() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        let mut responses = std::collections::HashMap::new();
        responses.insert(
            "<1> Hello world\n<2> Goodbye world".to_string(),
            "<1> translated1\n<2> translated2".to_string(),
        );

        let translator = Arc::new(FakeTranslator::new(responses));
        translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>)
            .await
            .unwrap();

        // A valid first reply means no retry.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 1);
    }

    #[test]
    fn dry_run_summary_reports_cues_chars_chunks() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
        let summary = dry_run_summary(ass.document(), "pt-BR");

        assert!(summary.contains("2 cues"), "expected '2 cues' in: {summary}");
        assert!(summary.contains("1 chunk(s)"), "expected '1 chunk(s)' in: {summary}");
        assert!(summary.contains("→ pt-BR"), "expected '→ pt-BR' in: {summary}");
    }

    #[test]
    fn dry_run_summary_counts_chars() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();
        let summary = dry_run_summary(ass.document(), "en");

        // "Hello world" (11) + "Goodbye world" (13).
        assert!(summary.contains("24 chars"), "expected '24 chars' in: {summary}");
    }

    #[test]
    fn dry_run_summary_handles_empty_document() {
        let ass_content = r"[Script Info]
Title: Empty
ScriptType: v4.00+

[V4+ Styles]
Format: Name, Fontname, Fontsize
Style: Default,Arial,20

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
";
        let ass = AssSubtitle::parse(ass_content).unwrap();
        let summary = dry_run_summary(ass.document(), "de");

        assert!(summary.contains("0 cues"), "expected '0 cues' in: {summary}");
        assert!(summary.contains("0 chars"), "expected '0 chars' in: {summary}");
        assert!(summary.contains("0 chunk(s)"), "expected '0 chunk(s)' in: {summary}");
    }

    #[test]
    fn dry_run_summary_splits_large_documents() {
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Big".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=300 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,This is subtitle line number {i} with enough text to fill space",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
        let summary = dry_run_summary(ass.document(), "pt-BR");

        assert!(summary.contains("300 cues"), "expected '300 cues' in: {summary}");
        assert!(
            summary.contains("2 chunk(s)"),
            "expected '2 chunk(s)' in: {summary}"
        );
    }

    #[tokio::test]
    async fn pipeline_retries_chunk_on_malformed_output() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        let translator = Arc::new(FakeTranslator::with_sequential_responses(vec![
            "<1> Olá mundo".to_string(), // first reply is missing id <2>
            "<1> Olá mundo\n<2> Adeus mundo".to_string(), // retry succeeds
        ]));

        let result = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>).await.unwrap();
        let rendered = result.render();

        assert!(rendered.contains("Olá mundo"));
        assert!(rendered.contains("Adeus mundo"));

        // Initial attempt plus one retry.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 2);
    }

    #[tokio::test]
    async fn pipeline_gives_up_after_repeated_malformed_output() {
        let ass = AssSubtitle::parse(SIMPLE_ASS).unwrap();

        // Every reply is missing id <2>.
        let translator = Arc::new(FakeTranslator::with_sequential_responses(vec![
            "<1> Olá mundo".to_string(),
            "<1> Olá mundo".to_string(),
            "<1> Olá mundo".to_string(),
            "<1> Olá mundo".to_string(),
        ]));

        let err = translate_ass(ass, "pt-BR", 1, translator.clone() as Arc<dyn Translator>)
            .await
            .unwrap_err();

        assert!(err.to_string().contains("missing id <2>"));

        // `retry_async(3, ..)` in the pipeline results in 4 attempts in total.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 4);
    }

    #[test]
    fn progress_file_path_for_directory() {
        // Path has no extension (and need not exist), so it is used as the directory.
        let path = progress_file_path(std::path::Path::new("/media/anime"));
        assert_eq!(
            path,
            std::path::PathBuf::from("/media/anime/.psyche-subtitle-toolkit-progress.json")
        );
    }

    #[test]
    fn progress_file_path_for_file() {
        let path = progress_file_path(std::path::Path::new("/media/anime/episode.mkv"));
        assert_eq!(
            path,
            std::path::PathBuf::from("/media/anime/.psyche-subtitle-toolkit-progress.json")
        );
    }

    // The progress-file tests below re-implement the JSON load/save steps inline
    // rather than calling `translate_mkv`.
    #[tokio::test]
    async fn progress_file_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");

        let completed = vec![
            "/media/anime/ep1.mkv".to_string(),
            "/media/anime/ep2.mkv".to_string(),
        ];
        let json = serde_json::to_string_pretty(&completed).unwrap();
        tokio::fs::write(&progress_path, &json).await.unwrap();

        let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
        let loaded: Vec<String> = serde_json::from_str(&data).unwrap();

        assert_eq!(loaded, completed);
    }

    #[tokio::test]
    async fn progress_file_handles_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");

        let completed: Vec<String> = if progress_path.exists() {
            let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
            serde_json::from_str(&data).unwrap_or_default()
        } else {
            Vec::new()
        };

        assert!(completed.is_empty());
    }

    #[tokio::test]
    async fn progress_file_handles_corrupted_json() {
        let dir = tempfile::tempdir().unwrap();
        let progress_path = dir.path().join(".psyche-subtitle-toolkit-progress.json");

        tokio::fs::write(&progress_path, "not valid json")
            .await
            .unwrap();

        let data = tokio::fs::read_to_string(&progress_path).await.unwrap();
        let completed: Vec<String> = serde_json::from_str(&data).unwrap_or_default();

        assert!(completed.is_empty());
    }

    #[tokio::test]
    async fn pipeline_translates_concurrently() {
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Concurrent".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=300 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,line {i}",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));

        // Two chunks, up to two in flight.
        let result = translate_ass(
            ass,
            "pt-BR",
            2,
            translator.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        let rendered = result.render();
        for i in 1..=300 {
            assert!(rendered.contains(&format!("line {i}")), "missing cue {i}");
        }

        let texts = translator.received_texts();
        assert_eq!(texts.len(), 2);
    }

    #[tokio::test]
    async fn concurrent_translation_preserves_all_cues() {
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Stress".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=1000 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,stress line {i}",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));

        // 1000 cues -> 5 chunks of 200, all allowed in flight at once.
        let result = translate_ass(
            ass,
            "pt-BR",
            5,
            translator.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        let rendered = result.render();
        for i in 1..=1000 {
            assert!(
                rendered.contains(&format!("stress line {i}")),
                "missing cue {i} under concurrent translation"
            );
        }

        let texts = translator.received_texts();
        assert_eq!(texts.len(), 5, "expected 5 chunk calls, got {}", texts.len());
    }

    #[tokio::test]
    async fn concurrent_translation_output_is_deterministic() {
        // Builds a fresh 500-cue script each time it is called.
        let make_ass = || {
            let mut lines = vec![
                "[Script Info]".to_string(),
                "Title: Deterministic".to_string(),
                "ScriptType: v4.00+".to_string(),
                "".to_string(),
                "[V4+ Styles]".to_string(),
                "Format: Name, Fontname, Fontsize".to_string(),
                "Style: Default,Arial,20".to_string(),
                "".to_string(),
                "[Events]".to_string(),
                "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                    .to_string(),
            ];
            for i in 1..=500 {
                lines.push(format!(
                    "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,det line {i}",
                    i, i + 1,
                ));
            }
            AssSubtitle::parse(&lines.join("\n")).unwrap()
        };

        let t1 = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
        let r1 = translate_ass(
            make_ass(),
            "pt-BR",
            3,
            t1.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        let t2 = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));
        let r2 = translate_ass(
            make_ass(),
            "pt-BR",
            3,
            t2.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        assert_eq!(r1.render(), r2.render(), "concurrent output is non-deterministic");
    }

    #[tokio::test]
    async fn concurrent_error_propagates_correctly() {
        // 400 cues -> 2 chunks, both of which fail.
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: ErrProp".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=400 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,line {i}",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();

        let translator = Arc::new(FakeTranslator::with_error("provider down"));

        let err = translate_ass(
            ass,
            "pt-BR",
            3,
            translator.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap_err();

        assert!(err.to_string().contains("provider down"));
    }

    /// Translator that echoes its input and records the peak number of overlapping calls.
    struct ConcurrencyTrackingTranslator {
        /// Calls currently inside `translate`.
        active: std::sync::atomic::AtomicU32,
        /// Highest value `active` has reached.
        max_observed: std::sync::atomic::AtomicU32,
        /// Every `source_text` received, in call order.
        received: Mutex<Vec<String>>,
    }

    impl ConcurrencyTrackingTranslator {
        fn new() -> Self {
            Self {
                active: std::sync::atomic::AtomicU32::new(0),
                max_observed: std::sync::atomic::AtomicU32::new(0),
                received: Mutex::new(Vec::new()),
            }
        }

        /// Peak number of `translate` calls observed running at the same time.
        fn max_concurrent_calls(&self) -> u32 {
            self.max_observed.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl Translator for ConcurrencyTrackingTranslator {
        async fn translate(&self, request: TranslationRequest<'_>) -> crate::error::Result<String> {
            self.received.lock().unwrap().push(request.source_text.to_string());

            let current = self.active.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
            self.max_observed.fetch_max(current, std::sync::atomic::Ordering::SeqCst);

            // Suspend once so other in-flight chunk tasks can run and overlap with this call.
            tokio::task::yield_now().await;

            self.active.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);

            Ok(request.source_text.to_string())
        }
    }

    #[tokio::test]
    async fn semaphore_bounds_concurrency() {
        // 600 cues -> 3 chunks, limited to 2 in flight.
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Semaphore".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=600 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,sem line {i}",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
        let translator = Arc::new(ConcurrencyTrackingTranslator::new());

        let result = translate_ass(
            ass,
            "pt-BR",
            2, // concurrency limit under test
            translator.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        let rendered = result.render();
        for i in 1..=600 {
            assert!(rendered.contains(&format!("sem line {i}")), "missing cue {i}");
        }

        let max = translator.max_concurrent_calls();
        assert!(
            max <= 2,
            "semaphore failed to bound concurrency: observed {max} concurrent calls (expected <= 2)"
        );

        let texts = translator.received.lock().unwrap();
        assert_eq!(texts.len(), 3);
    }

    #[tokio::test]
    async fn sequential_mode_is_deterministic() {
        let mut lines = vec![
            "[Script Info]".to_string(),
            "Title: Seq".to_string(),
            "ScriptType: v4.00+".to_string(),
            "".to_string(),
            "[V4+ Styles]".to_string(),
            "Format: Name, Fontname, Fontsize".to_string(),
            "Style: Default,Arial,20".to_string(),
            "".to_string(),
            "[Events]".to_string(),
            "Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text"
                .to_string(),
        ];
        for i in 1..=500 {
            lines.push(format!(
                "Dialogue: 0,0:00:{:02}.00,0:00:{:02}.00,Default,,0,0,0,,seq line {i}",
                i, i + 1,
            ));
        }
        let ass = AssSubtitle::parse(&lines.join("\n")).unwrap();
        let translator = Arc::new(FakeTranslator::new(std::collections::HashMap::new()));

        translate_ass(
            ass,
            "pt-BR",
            1, // one chunk at a time
            translator.clone() as Arc<dyn Translator>,
        )
        .await
        .unwrap();

        // 500 cues -> chunks of 200, 200 and 100, requested in document order.
        let texts = translator.received_texts();
        assert_eq!(texts.len(), 3);

        assert!(texts[0].starts_with("<1> "), "first chunk should start with <1>");
        assert!(texts[1].starts_with("<201> "), "second chunk should start with <201>");
        assert!(texts[2].starts_with("<401> "), "third chunk should start with <401>");
    }
}