use std::collections::{HashMap, HashSet};
use futures::stream::{self, StreamExt, TryStreamExt};
use crate::domain::{SubtitleDocument, TranslationOptions, TranslationResult};
use crate::error::TranslatorError;
use crate::formats;
use crate::providers::{TranslatedItem, TranslationBatch, TranslationProvider};
use super::chunker;
/// Translates the cues of `document` with `provider` and renders the result.
///
/// The pipeline is:
/// 1. validate `options`;
/// 2. classify the document with `options.ass_classification_policy`;
/// 3. split the classified cues into batches via `chunker::chunk_cues`;
/// 4. send the batches to the provider, keeping at most
///    `options.max_parallel_batches` requests in flight;
/// 5. validate every response against the batch that produced it;
/// 6. apply the collected translations and render the translated document.
///
/// When the classified document has no cues, or chunking yields no batches,
/// the classified document is rendered as-is and reported with `0` batches;
/// the provider is never asked to translate anything in that case.
///
/// # Errors
///
/// Returns the first error produced by option validation, a provider call,
/// response validation, applying the translations, or rendering. A failing
/// batch aborts the whole run.
pub async fn translate_document(
    provider: &dyn TranslationProvider,
    document: &SubtitleDocument,
    options: &TranslationOptions,
) -> Result<TranslationResult, TranslatorError> {
    options.validate()?;

    let (classified_document, classification_report) =
        formats::classify_document(document, &options.ass_classification_policy);

    // Nothing to translate at all: skip the provider entirely.
    if classified_document.cues().is_empty() {
        let rendered = formats::render_subtitle(&classified_document)?;
        return Ok(TranslationResult::new(
            classified_document,
            rendered,
            0,
            classification_report,
        ));
    }

    let capabilities = provider.capabilities();
    let batches = chunker::chunk_cues(classified_document.cues(), options, &capabilities);

    // Cues exist but none of them ended up in a batch.
    if batches.is_empty() {
        let rendered = formats::render_subtitle(&classified_document)?;
        return Ok(TranslationResult::new(
            classified_document,
            rendered,
            0,
            classification_report,
        ));
    }

    // `buffer_unordered` with a limit of 0 never pulls from the underlying
    // stream, so the await below would stall forever. Clamp defensively in
    // case `options.validate()` does not already reject a zero limit.
    let max_in_flight = options.max_parallel_batches.max(1);

    // Each future borrows its batch from `batches` and clones it exactly once
    // for the outgoing request; the borrowed original is then used to validate
    // the provider's answer.
    let responses = stream::iter(batches.iter().map(|batch| async move {
        let request = TranslationBatch {
            source_language: options.source_language.clone(),
            target_language: options.target_language.clone(),
            system_prompt: options.system_prompt.clone(),
            items: batch.clone(),
        };
        let response = provider.translate_batch(request).await?;
        validate_batch_response(batch, &response.items)?;
        Ok::<Vec<TranslatedItem>, TranslatorError>(response.items)
    }))
    .buffer_unordered(max_in_flight)
    .try_collect::<Vec<_>>()
    .await?;

    // Responses arrive in completion order; keying by cue id makes the order
    // irrelevant when the translations are applied.
    let translations = responses
        .into_iter()
        .flatten()
        .map(|item| (item.id, item.text))
        .collect::<HashMap<_, _>>();

    let translated_document = classified_document.translated_with(&translations)?;
    let rendered = formats::render_subtitle(&translated_document)?;
    Ok(TranslationResult::new(
        translated_document,
        rendered,
        batches.len(),
        classification_report,
    ))
}
/// Checks that a provider response is a usable answer to the batch that was sent.
///
/// The checks run in this order, and the first failure is returned as a
/// `TranslatorError::Validation`:
/// 1. the response has as many items as the request;
/// 2. the set of cue ids in the response equals the set of ids in the request;
/// 3. no translated text is empty after trimming whitespace.
fn validate_batch_response(
    request_items: &[crate::providers::TranslationBatchItem],
    response_items: &[TranslatedItem],
) -> Result<(), TranslatorError> {
    let sent = request_items.len();
    let received = response_items.len();
    if sent != received {
        return Err(TranslatorError::Validation(format!(
            "provider returned {} items for a batch of {} cues",
            received, sent
        )));
    }

    // Ids are compared as sets, so the provider may answer in any order.
    let requested_ids: HashSet<_> = request_items.iter().map(|item| item.id.as_str()).collect();
    let returned_ids: HashSet<_> = response_items.iter().map(|item| item.id.as_str()).collect();
    if requested_ids != returned_ids {
        return Err(TranslatorError::Validation(
            "provider response cue IDs do not match the request".to_owned(),
        ));
    }

    let every_item_has_text = response_items
        .iter()
        .all(|item| !item.text.trim().is_empty());
    if !every_item_has_text {
        return Err(TranslatorError::Validation(
            "provider returned an empty translated text".to_owned(),
        ));
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    //! Tests for `translate_document` driven by small in-memory providers.

    use std::collections::BTreeMap;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};
    use async_trait::async_trait;
    use tokio::time::{Duration, sleep};
    use crate::domain::{
        CueKind, SubtitleCue, SubtitleDocument, SubtitleFormat, TranslationOptions,
    };
    use crate::formats::ass;
    use crate::providers::{
        TranslatedItem, TranslationBatch, TranslationBatchOutput, TranslationProvider,
    };
    use super::translate_document;

    /// Builds a provider response that keeps every cue id and prepends
    /// `prefix` to the source text, so assertions can recognise translated cues.
    fn prefixed_output(prefix: &str, batch: TranslationBatch) -> TranslationBatchOutput {
        TranslationBatchOutput {
            items: batch
                .items
                .into_iter()
                .map(|item| TranslatedItem {
                    id: item.id,
                    text: format!("{prefix}{}", item.text),
                })
                .collect(),
        }
    }

    /// Provider that answers immediately with `translated:<source text>`.
    struct EchoProvider;

    #[async_trait]
    impl TranslationProvider for EchoProvider {
        fn name(&self) -> &str {
            "echo"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            Ok(prefixed_output("translated:", batch))
        }
    }

    /// A single SRT cue is translated and shows up in the rendered output.
    #[tokio::test]
    async fn translates_all_cues() {
        let document = SubtitleDocument::from_parts(
            SubtitleFormat::Srt,
            vec![SubtitleCue::new(
                "cue-1",
                Some("1".to_owned()),
                "00:00:00,000",
                "00:00:01,000",
                None,
                "hello",
                BTreeMap::new(),
            )],
            crate::domain::RenderPlan::Srt,
        );
        let options = TranslationOptions::default();

        let result = translate_document(&EchoProvider, &document, &options)
            .await
            .expect("translation should succeed");

        assert_eq!(result.document().cues()[0].text(), "translated:hello");
        assert!(result.rendered().contains("translated:hello"));
    }

    /// Provider that records every source text it is asked to translate and
    /// answers with `translated:<source text>`.
    struct RecordingProvider {
        seen: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl TranslationProvider for RecordingProvider {
        fn name(&self) -> &str {
            "recording"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            // The guard is a temporary, so the lock is released at the end of
            // this statement.
            self.seen
                .lock()
                .expect("seen mutex should lock")
                .extend(batch.items.iter().map(|item| item.text.clone()));
            Ok(prefixed_output("translated:", batch))
        }
    }

    /// Only the plain dialogue cue reaches the provider; the karaoke and song
    /// cues keep their original text.
    #[tokio::test]
    async fn preserves_non_translatable_ass_cues() {
        let source = "[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\nDialogue: 0,0:00:01.00,0:00:03.00,Default,,0,0,0,,Hello there\nDialogue: 0,0:00:03.10,0:00:04.00,Default,,0,0,0,,{\\k20}Ka{\\k20}ra\nDialogue: 0,0:00:04.10,0:00:05.00,Lyrics,,0,0,0,,Shining star\n";
        let document = ass::parse(source).expect("parse should succeed");
        let seen = Arc::new(Mutex::new(Vec::new()));
        let provider = RecordingProvider { seen: seen.clone() };

        let result = translate_document(&provider, &document, &TranslationOptions::default())
            .await
            .expect("translation should succeed");

        assert_eq!(result.document().cues()[0].kind(), CueKind::Dialogue);
        assert_eq!(result.document().cues()[0].text(), "translated:Hello there");
        assert_eq!(result.document().cues()[1].kind(), CueKind::Karaoke);
        assert_eq!(result.document().cues()[1].text(), "{\\k20}Ka{\\k20}ra");
        assert_eq!(result.document().cues()[2].kind(), CueKind::Song);
        assert_eq!(result.document().cues()[2].text(), "Shining star");
        assert_eq!(
            seen.lock().expect("seen mutex should lock").as_slice(),
            &["Hello there"]
        );
        assert_eq!(result.batches(), 1);
    }

    /// Provider that takes 40 ms per batch and tracks how many calls were in
    /// flight at the same time.
    struct SlowProvider {
        /// Number of `translate_batch` calls currently running.
        current: Arc<AtomicUsize>,
        /// Highest value `current` has reached so far.
        max_seen: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl TranslationProvider for SlowProvider {
        fn name(&self) -> &str {
            "slow"
        }

        async fn translate_batch(
            &self,
            batch: TranslationBatch,
        ) -> Result<TranslationBatchOutput, crate::error::TranslatorError> {
            let current = self.current.fetch_add(1, Ordering::SeqCst) + 1;
            // Atomically raise the high-water mark if this call set a new peak.
            self.max_seen.fetch_max(current, Ordering::SeqCst);
            sleep(Duration::from_millis(40)).await;
            self.current.fetch_sub(1, Ordering::SeqCst);
            Ok(prefixed_output("ok:", batch))
        }
    }

    /// Six single-cue batches with a parallelism limit of four must overlap,
    /// but never exceed the configured limit.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn translates_batches_in_parallel() {
        let cues = (0..6)
            .map(|index| {
                SubtitleCue::new(
                    format!("cue-{index}"),
                    None,
                    format!("00:00:0{index},000"),
                    format!("00:00:0{},500", index + 1),
                    None,
                    format!("line-{index}"),
                    BTreeMap::new(),
                )
            })
            .collect::<Vec<_>>();
        let document =
            SubtitleDocument::from_parts(SubtitleFormat::Srt, cues, crate::domain::RenderPlan::Srt);
        let current = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));
        let provider = SlowProvider {
            current: current.clone(),
            max_seen: max_seen.clone(),
        };
        let options = TranslationOptions {
            max_batch_items: 1,
            max_parallel_batches: 4,
            ..TranslationOptions::default()
        };

        let result = translate_document(&provider, &document, &options)
            .await
            .expect("translation should succeed");

        assert_eq!(result.batches(), 6);
        let peak = max_seen.load(Ordering::SeqCst);
        assert!(peak >= 2);
        // The pipeline must not run more batches at once than it was allowed to.
        assert!(peak <= 4);
        assert!(result.rendered().contains("ok:line-0"));
    }
}