Skip to main content

psyche_subtitle_toolkit/translation/
deepl.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::{Result, SubtitleToolkitError};
4
5use super::{TranslationRequest, Translator};
6
7/// Translator backend that calls the [DeepL](https://www.deepl.com) `/v2/translate` endpoint.
8///
9/// Supports both the free tier (`https://api-free.deepl.com`) and the pro tier
10/// (`https://api.deepl.com`). The default base URL targets the free tier.
11///
12/// # Example
13///
14/// ```no_run
15/// # async fn example() -> psyche_subtitle_toolkit::Result<()> {
16/// use psyche_subtitle_toolkit::DeepLTranslator;
17///
18/// let translator = DeepLTranslator::new("your-api-key")?;
19/// // let result = translator.translate(request).await?;
20/// # Ok(())
21/// # }
22/// ```
23#[derive(Debug, Clone)]
24pub struct DeepLTranslator {
25    client: reqwest::Client,
26    base_url: String,
27    api_key: String,
28}
29
30impl DeepLTranslator {
31    /// Create a new translator targeting the DeepL free API (`https://api-free.deepl.com`).
32    pub fn new(api_key: impl Into<String>) -> Result<Self> {
33        Self::with_base_url("https://api-free.deepl.com", api_key)
34    }
35
36    /// Create a new translator with a custom base URL.
37    ///
38    /// Use `"https://api.deepl.com"` for the pro tier, or
39    /// `"https://api-free.deepl.com"` for the free tier (default).
40    pub fn with_base_url(
41        base_url: impl Into<String>,
42        api_key: impl Into<String>,
43    ) -> Result<Self> {
44        let client = reqwest::Client::builder()
45            .timeout(std::time::Duration::from_secs(120))
46            .build()
47            .map_err(SubtitleToolkitError::Http)?;
48        Ok(Self {
49            client,
50            base_url: base_url.into().trim_end_matches('/').to_string(),
51            api_key: api_key.into(),
52        })
53    }
54}
55
56#[async_trait::async_trait]
57impl Translator for DeepLTranslator {
58    async fn translate(&self, request: TranslationRequest<'_>) -> Result<String> {
59        let response = self
60            .client
61            .post(format!("{}/v2/translate", self.base_url))
62            .header(
63                "Authorization",
64                format!("DeepL-Auth-Key {}", self.api_key),
65            )
66            .json(&DeepLTranslateRequest {
67                text: request.source_text.lines().collect(),
68                target_lang: &request.target_language.to_uppercase(),
69                source_lang: request.source_language.map(|l| l.to_uppercase()),
70                split_sentences: "0",
71                preserve_formatting: true,
72            })
73            .send()
74            .await?;
75
76        if !response.status().is_success() {
77            return Err(SubtitleToolkitError::Translation {
78                provider: "deepl",
79                message: response
80                    .text()
81                    .await
82                    .unwrap_or_else(|_| "request failed".into()),
83            });
84        }
85
86        let body = response.json::<DeepLTranslateResponse>().await?;
87        if body.translations.is_empty() {
88            return Err(SubtitleToolkitError::Translation {
89                provider: "deepl",
90                message: "response contained no translations".into(),
91            });
92        }
93        let translated = body
94            .translations
95            .into_iter()
96            .map(|t| t.text)
97            .collect::<Vec<_>>()
98            .join("\n");
99
100        Ok(translated)
101    }
102}
103
104#[derive(Debug, Serialize)]
105struct DeepLTranslateRequest<'a> {
106    text: Vec<&'a str>,
107    target_lang: &'a str,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    source_lang: Option<String>,
110    split_sentences: &'static str,
111    preserve_formatting: bool,
112}
113
114#[derive(Debug, Deserialize)]
115struct DeepLTranslateResponse {
116    translations: Vec<DeepLTranslation>,
117}
118
119#[derive(Debug, Deserialize)]
120struct DeepLTranslation {
121    text: String,
122}
123
#[cfg(test)]
mod tests {
    //! Tests for [`DeepLTranslator`] against a local `wiremock` server.
    //!
    //! When a mounted mock's matchers do not match the outgoing request, the
    //! mock server does not answer with the stubbed 200 response, so the
    //! `unwrap()` on the translation result fails the test.

    use super::*;
    use wiremock::matchers::{body_string_contains, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A two-line input yields the two translated lines joined by `\n`.
    #[tokio::test]
    async fn translates_numbered_text() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v2/translate"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [{ "text": "<1> Olá" }, { "text": "<2> mundo" }]
            })))
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt-BR",
                source_language: None,
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Olá\n<2> mundo");
    }

    /// The API key is sent as `Authorization: DeepL-Auth-Key <key>`.
    #[tokio::test]
    async fn sends_deepl_auth_header() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(header("Authorization", "DeepL-Auth-Key my-secret-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [{ "text": "<1> ok" }]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "my-secret-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> test",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap();
    }

    /// `"pt-BR"` must reach the server as `"PT-BR"` in the request body.
    #[tokio::test]
    async fn uppercases_target_language() {
        let server = MockServer::start().await;

        // The body matcher is what actually verifies the upper-casing; without
        // it this test would pass for any target language value.
        Mock::given(method("POST"))
            .and(body_string_contains(r#""target_lang":"PT-BR""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [{ "text": "<1> ok" }]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> test",
                target_language: "pt-BR",
                source_language: None,
            })
            .await
            .unwrap();
    }

    /// A non-success status becomes an error naming the provider and the body.
    #[tokio::test]
    async fn error_on_non_200() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(403).set_body_string("Quota exceeded"))
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "bad-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("deepl"));
        assert!(err.to_string().contains("Quota exceeded"));
    }

    /// A 200 response with an empty `translations` array is an error.
    #[tokio::test]
    async fn error_on_empty_translations() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": []
            })))
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("no translations"));
    }

    /// A provided source language is upper-cased and sent as `source_lang`.
    #[tokio::test]
    async fn sends_source_lang_when_provided() {
        let server = MockServer::start().await;

        // Match the complete value, closing quote included, so that a longer
        // code starting with "JA" would not satisfy the matcher.
        Mock::given(method("POST"))
            .and(path("/v2/translate"))
            .and(body_string_contains(r#""source_lang":"JA""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [{ "text": "<1> Olá" }]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt-BR",
                source_language: Some("ja"),
            })
            .await
            .unwrap();
    }

    /// Each input line becomes its own element of the `text` array.
    #[tokio::test]
    async fn sends_each_line_as_separate_array_element() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v2/translate"))
            .and(body_string_contains(r#""<1> hello""#))
            .and(body_string_contains(r#""<2> world""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [{ "text": "<1> Olá" }, { "text": "<2> mundo" }]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt-BR",
                source_language: None,
            })
            .await
            .unwrap();
    }

    /// Translated segments are joined with `\n` in response order.
    #[tokio::test]
    async fn joins_translated_lines_back() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "translations": [
                    { "text": "<1> Zeile eins" },
                    { "text": "<2> Zeile zwei" },
                    { "text": "<3> Zeile drei" }
                ]
            })))
            .mount(&server)
            .await;

        let translator = DeepLTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> Line one\n<2> Line two\n<3> Line three",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Zeile eins\n<2> Zeile zwei\n<3> Zeile drei");
    }
}