psyche_subtitle_toolkit/translation/
deepl.rs1use serde::{Deserialize, Serialize};
2
3use crate::error::{Result, SubtitleToolkitError};
4
5use super::{TranslationRequest, Translator};
6
7#[derive(Debug, Clone)]
24pub struct DeepLTranslator {
25 client: reqwest::Client,
26 base_url: String,
27 api_key: String,
28}
29
30impl DeepLTranslator {
31 pub fn new(api_key: impl Into<String>) -> Result<Self> {
33 Self::with_base_url("https://api-free.deepl.com", api_key)
34 }
35
36 pub fn with_base_url(
41 base_url: impl Into<String>,
42 api_key: impl Into<String>,
43 ) -> Result<Self> {
44 let client = reqwest::Client::builder()
45 .timeout(std::time::Duration::from_secs(120))
46 .build()
47 .map_err(SubtitleToolkitError::Http)?;
48 Ok(Self {
49 client,
50 base_url: base_url.into().trim_end_matches('/').to_string(),
51 api_key: api_key.into(),
52 })
53 }
54}
55
56#[async_trait::async_trait]
57impl Translator for DeepLTranslator {
58 async fn translate(&self, request: TranslationRequest<'_>) -> Result<String> {
59 let response = self
60 .client
61 .post(format!("{}/v2/translate", self.base_url))
62 .header(
63 "Authorization",
64 format!("DeepL-Auth-Key {}", self.api_key),
65 )
66 .json(&DeepLTranslateRequest {
67 text: request.source_text.lines().collect(),
68 target_lang: &request.target_language.to_uppercase(),
69 source_lang: request.source_language.map(|l| l.to_uppercase()),
70 split_sentences: "0",
71 preserve_formatting: true,
72 })
73 .send()
74 .await?;
75
76 if !response.status().is_success() {
77 return Err(SubtitleToolkitError::Translation {
78 provider: "deepl",
79 message: response
80 .text()
81 .await
82 .unwrap_or_else(|_| "request failed".into()),
83 });
84 }
85
86 let body = response.json::<DeepLTranslateResponse>().await?;
87 if body.translations.is_empty() {
88 return Err(SubtitleToolkitError::Translation {
89 provider: "deepl",
90 message: "response contained no translations".into(),
91 });
92 }
93 let translated = body
94 .translations
95 .into_iter()
96 .map(|t| t.text)
97 .collect::<Vec<_>>()
98 .join("\n");
99
100 Ok(translated)
101 }
102}
103
104#[derive(Debug, Serialize)]
105struct DeepLTranslateRequest<'a> {
106 text: Vec<&'a str>,
107 target_lang: &'a str,
108 #[serde(skip_serializing_if = "Option::is_none")]
109 source_lang: Option<String>,
110 split_sentences: &'static str,
111 preserve_formatting: bool,
112}
113
114#[derive(Debug, Deserialize)]
115struct DeepLTranslateResponse {
116 translations: Vec<DeepLTranslation>,
117}
118
119#[derive(Debug, Deserialize)]
120struct DeepLTranslation {
121 text: String,
122}
123
124#[cfg(test)]
125mod tests {
126 use super::*;
127 use wiremock::matchers::{header, method, path};
128 use wiremock::{Mock, MockServer, ResponseTemplate};
129
130 #[tokio::test]
131 async fn translates_numbered_text() {
132 let server = MockServer::start().await;
133
134 Mock::given(method("POST"))
135 .and(path("/v2/translate"))
136 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
137 "translations": [{ "text": "<1> Olá" }, { "text": "<2> mundo" }]
138 })))
139 .mount(&server)
140 .await;
141
142 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
143 let result = translator
144 .translate(TranslationRequest {
145 source_text: "<1> hello\n<2> world",
146 target_language: "pt-BR",
147 source_language: None,
148 })
149 .await
150 .unwrap();
151
152 assert_eq!(result, "<1> Olá\n<2> mundo");
153 }
154
155 #[tokio::test]
156 async fn sends_deepl_auth_header() {
157 let server = MockServer::start().await;
158
159 Mock::given(method("POST"))
160 .and(header("Authorization", "DeepL-Auth-Key my-secret-key"))
161 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
162 "translations": [{ "text": "<1> ok" }]
163 })))
164 .mount(&server)
165 .await;
166
167 let translator = DeepLTranslator::with_base_url(server.uri(), "my-secret-key").unwrap();
168 translator
169 .translate(TranslationRequest {
170 source_text: "<1> test",
171 target_language: "de",
172 source_language: None,
173 })
174 .await
175 .unwrap();
176 }
177
178 #[tokio::test]
179 async fn uppercases_target_language() {
180 let server = MockServer::start().await;
181
182 Mock::given(method("POST"))
183 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
184 "translations": [{ "text": "<1> ok" }]
185 })))
186 .mount(&server)
187 .await;
188
189 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
190 translator
192 .translate(TranslationRequest {
193 source_text: "<1> test",
194 target_language: "pt-BR",
195 source_language: None,
196 })
197 .await
198 .unwrap();
199 }
200
201 #[tokio::test]
202 async fn error_on_non_200() {
203 let server = MockServer::start().await;
204
205 Mock::given(method("POST"))
206 .respond_with(
207 ResponseTemplate::new(403).set_body_string("Quota exceeded"),
208 )
209 .mount(&server)
210 .await;
211
212 let translator = DeepLTranslator::with_base_url(server.uri(), "bad-key").unwrap();
213 let err = translator
214 .translate(TranslationRequest {
215 source_text: "<1> hello",
216 target_language: "de",
217 source_language: None,
218 })
219 .await
220 .unwrap_err();
221
222 assert!(err.to_string().contains("deepl"));
223 assert!(err.to_string().contains("Quota exceeded"));
224 }
225
226 #[tokio::test]
227 async fn error_on_empty_translations() {
228 let server = MockServer::start().await;
229
230 Mock::given(method("POST"))
231 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
232 "translations": []
233 })))
234 .mount(&server)
235 .await;
236
237 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
238 let err = translator
239 .translate(TranslationRequest {
240 source_text: "<1> hello",
241 target_language: "de",
242 source_language: None,
243 })
244 .await
245 .unwrap_err();
246
247 assert!(err.to_string().contains("no translations"));
248 }
249
250 #[tokio::test]
251 async fn sends_source_lang_when_provided() {
252 let server = MockServer::start().await;
253
254 Mock::given(method("POST"))
255 .and(path("/v2/translate"))
256 .and(wiremock::matchers::body_string_contains(r#""source_lang":"JA"#))
257 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
258 "translations": [{ "text": "<1> Olá" }]
259 })))
260 .expect(1)
261 .mount(&server)
262 .await;
263
264 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
265 translator
266 .translate(TranslationRequest {
267 source_text: "<1> hello",
268 target_language: "pt-BR",
269 source_language: Some("ja"),
270 })
271 .await
272 .unwrap();
273 }
274
275 #[tokio::test]
276 async fn sends_each_line_as_separate_array_element() {
277 let server = MockServer::start().await;
278
279 Mock::given(method("POST"))
280 .and(path("/v2/translate"))
281 .and(wiremock::matchers::body_string_contains(r#""<1> hello""#))
282 .and(wiremock::matchers::body_string_contains(r#""<2> world""#))
283 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
284 "translations": [{ "text": "<1> Olá" }, { "text": "<2> mundo" }]
285 })))
286 .expect(1)
287 .mount(&server)
288 .await;
289
290 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
291 translator
292 .translate(TranslationRequest {
293 source_text: "<1> hello\n<2> world",
294 target_language: "pt-BR",
295 source_language: None,
296 })
297 .await
298 .unwrap();
299 }
300
301 #[tokio::test]
302 async fn joins_translated_lines_back() {
303 let server = MockServer::start().await;
304
305 Mock::given(method("POST"))
306 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
307 "translations": [
308 { "text": "<1> Zeile eins" },
309 { "text": "<2> Zeile zwei" },
310 { "text": "<3> Zeile drei" }
311 ]
312 })))
313 .mount(&server)
314 .await;
315
316 let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
317 let result = translator
318 .translate(TranslationRequest {
319 source_text: "<1> Line one\n<2> Line two\n<3> Line three",
320 target_language: "de",
321 source_language: None,
322 })
323 .await
324 .unwrap();
325
326 assert_eq!(result, "<1> Zeile eins\n<2> Zeile zwei\n<3> Zeile drei");
327 }
328}