use anyhow::anyhow;
use async_openai::{config::OpenAIConfig, types::CreateChatCompletionRequest, Client};
use indicatif::ProgressBar;
use lazy_static::lazy_static;
use log::{debug, trace};
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::json;
use super::{
function_tool, function_tool_choice, retry_openai_request, system_message,
tool_call_response, user_message,
};
use crate::{
lang::Lang,
progress::default_progress_style,
srt::{Subtitle, SubtitleFile},
Result,
};
/// Minimum number of subtitles a chunk must contain before it may be closed.
/// Only the final, leftover chunk of a file can be smaller than this.
const MIN_CHUNK_SIZE: usize = 10;
/// Number of subtitles at which a chunk is closed unconditionally, even if
/// its last subtitle does not end a sentence.
const MAX_CHUNK_SIZE: usize = 15;
lazy_static! {
/// JSON Schema for the arguments of the `report_translations` tool: an
/// object with a required `lines` array, each element of which is an object
/// with required string fields `original` and `translation`.
///
/// It mirrors the serialized shape of `ReportTranslationParameters`. Note
/// that the schema marks `translation` as required, while the Rust struct
/// stores it as an `Option<String>`.
static ref REPORT_TRANSLATION_PARAMETERS_SCHEMA: serde_json::Value = json!({
"type": "object",
"properties": {
"lines": {
"type": "array",
"items": {
"type": "object",
"properties": {
"original": {
"type": "string"
},
"translation": {
"type": "string"
}
},
"required": [
"original",
"translation"
]
}
}
},
"required": [
"lines"
]
});
/// Matches text that ends with a character having the Unicode
/// `Sentence_Terminal` property, optionally followed by trailing whitespace.
/// Used to decide whether a subtitle ends a sentence when chunking.
static ref SENTENCE_END: Regex =
Regex::new(r"[\p{Sentence_Terminal}]\s*$").unwrap();
}
pub async fn translate_subtitle_file(
file: &SubtitleFile,
to_lang: Lang,
) -> Result<SubtitleFile> {
let from_lang = file.detect_language().ok_or_else(|| {
anyhow!("Could not detect the language of the input subtitle file")
})?;
let mut sub_chunks = vec![];
let mut current_chunk = vec![];
for sub in &file.subtitles {
current_chunk.push(sub.clone());
let last_line = sub.lines.last().cloned().unwrap_or_else(|| "".to_owned());
if current_chunk.len() >= MIN_CHUNK_SIZE
&& (current_chunk.len() >= MAX_CHUNK_SIZE
|| SENTENCE_END.is_match(&last_line))
{
sub_chunks.push(current_chunk.clone());
current_chunk.clear();
}
}
if current_chunk.len() > 0 {
sub_chunks.push(current_chunk);
}
let progress = ProgressBar::new(file.subtitles.len() as u64);
progress.set_style(default_progress_style());
progress.set_prefix("📖 Translating");
progress.tick();
let client = Client::new();
let mut translated_subs = vec![];
for chunk in &sub_chunks {
let translated_lines = retry_openai_request(|| {
translate_chunk(&client, chunk, from_lang, to_lang)
})
.await?;
for (sub, translated) in chunk.iter().zip(translated_lines) {
let mut translated_sub = sub.clone();
translated_sub.lines =
vec![translated.translation.clone().ok_or_else(|| {
anyhow!(
"OpenAI did not return a translation for a line: {:?}",
translated.original
)
})?];
translated_subs.push(translated_sub);
}
progress.inc(chunk.len() as u64);
}
progress.finish();
Ok(SubtitleFile {
subtitles: translated_subs,
})
}
async fn translate_chunk(
client: &Client<OpenAIConfig>,
chunk: &[Subtitle],
from_lang: Lang,
to_lang: Lang,
) -> Result<Vec<LineTranslation>> {
let prompt = prompt_from_chunk(chunk, from_lang, to_lang)?;
debug!("OpenAI request (prompt): {}", prompt);
let req = CreateChatCompletionRequest {
model: "gpt-3.5-turbo".to_owned(),
messages: vec![
system_message("You are a subtitle translator helping language learners."),
user_message(prompt),
],
tools: Some(vec![function_tool(
"report_translations",
"Report the translations of the lines of dialog.",
&REPORT_TRANSLATION_PARAMETERS_SCHEMA,
)]),
tool_choice: Some(function_tool_choice("report_translations")),
..Default::default()
};
trace!("OpenAI request (full): {:?}", req);
let resp = client.chat().create(req).await?;
trace!("OpenAI response (full): {:?}", resp);
let args = tool_call_response::<ReportTranslationParameters>(
&resp,
"report_translations",
)?;
let translated_lines = args.lines;
if translated_lines.len() != chunk.len() {
return Err(anyhow!(
"OpenAI returned the wrong number of translations: {}",
translated_lines.len()
));
}
Ok(translated_lines)
}
/// Build the user prompt asking the model to translate `chunk` from
/// `from_lang` to `to_lang`.
///
/// The prompt embeds a pretty-printed JSON template with one entry per
/// subtitle, where `original` is filled in and `translation` is `null`, and
/// asks the model to report its output via the `report_translations`
/// function.
///
/// # Errors
///
/// Returns an error if `english_names()` fails for either language.
fn prompt_from_chunk(
    chunk: &[Subtitle],
    from_lang: Lang,
    to_lang: Lang,
) -> Result<String> {
    let template = ReportTranslationParameters {
        lines: chunk
            .iter()
            .map(LineTranslation::template_from_subtitle)
            .collect(),
    };
    // The template contains only strings and optional strings.
    let json_template =
        serde_json::to_string_pretty(&template).expect("failed to format JSON");
    // The function name must match the tool registered with the request
    // (`report_translations`), and the closing fence sits on its own line so
    // the JSON block is well-formed Markdown.
    Ok(format!(
        "Translate the following consecutive lines of dialog from {from} to {to}:\n\
         ```json\n{template}\n```\n\
         Please call the function `report_translations` with your output.",
        from = from_lang.english_names()?[0],
        to = to_lang.english_names()?[0],
        template = json_template,
    ))
}
/// Arguments of the `report_translations` tool call.
///
/// The same type is serialized, with untranslated entries, as the JSON
/// template embedded in the prompt.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReportTranslationParameters {
/// One entry per line of dialog.
pub lines: Vec<LineTranslation>,
}
/// A single line of dialog paired with its translation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LineTranslation {
/// Text in the source language.
pub original: String,
/// Translated text. `None` in the template sent to the model; a `None`
/// in the model's response is reported as an error by
/// `translate_subtitle_file`.
pub translation: Option<String>,
}
impl LineTranslation {
pub fn template_from_subtitle(sub: &Subtitle) -> LineTranslation {
LineTranslation {
original: sub.lines.join(" "),
translation: None,
}
}
}