Skip to main content

psyche_subtitle_toolkit/translation/
google.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::{Result, SubtitleToolkitError};
4
5use super::{TranslationRequest, Translator};
6
7/// Translator backend that calls the [Google Cloud Translation v2](https://cloud.google.com/translate/docs/reference/rest/v2/translate) endpoint.
8///
9/// Uses API key authentication via the `key` query parameter.
10/// The first 500,000 characters/month are free.
11///
12/// # Example
13///
14/// ```no_run
15/// # async fn example() -> psyche_subtitle_toolkit::Result<()> {
16/// use psyche_subtitle_toolkit::GoogleTranslator;
17///
18/// let translator = GoogleTranslator::new("your-api-key")?;
19/// // let result = translator.translate(request).await?;
20/// # Ok(())
21/// # }
22/// ```
23#[derive(Debug, Clone)]
24pub struct GoogleTranslator {
25    client: reqwest::Client,
26    base_url: String,
27    api_key: String,
28}
29
30impl GoogleTranslator {
31    /// Create a new translator targeting the default Google Translate API
32    /// (`https://translation.googleapis.com`).
33    pub fn new(api_key: impl Into<String>) -> Result<Self> {
34        Self::with_base_url("https://translation.googleapis.com", api_key)
35    }
36
37    /// Create a new translator with a custom base URL and API key.
38    pub fn with_base_url(
39        base_url: impl Into<String>,
40        api_key: impl Into<String>,
41    ) -> Result<Self> {
42        let client = reqwest::Client::builder()
43            .timeout(std::time::Duration::from_secs(120))
44            .build()
45            .map_err(SubtitleToolkitError::Http)?;
46        Ok(Self {
47            client,
48            base_url: base_url.into().trim_end_matches('/').to_string(),
49            api_key: api_key.into(),
50        })
51    }
52}
53
54#[async_trait::async_trait]
55impl Translator for GoogleTranslator {
56    async fn translate(&self, request: TranslationRequest<'_>) -> Result<String> {
57        // Google v2 uses language codes like "pt", "en", "ja" — not "pt-BR".
58        // We pass the target as-is; Google handles regional codes (e.g. "pt-BR" works).
59        let response = self
60            .client
61            .post(format!(
62                "{}/language/translate/v2",
63                self.base_url
64            ))
65            .query(&[("key", &self.api_key)])
66            .json(&GoogleTranslateRequest {
67                q: request.source_text.lines().collect(),
68                target: request.target_language,
69                source: request.source_language,
70                format: "text",
71            })
72            .send()
73            .await?;
74
75        if !response.status().is_success() {
76            return Err(SubtitleToolkitError::Translation {
77                provider: "google",
78                message: response
79                    .text()
80                    .await
81                    .unwrap_or_else(|_| "request failed".into()),
82            });
83        }
84
85        let body = response.json::<GoogleTranslateResponse>().await?;
86        if body.data.translations.is_empty() {
87            return Err(SubtitleToolkitError::Translation {
88                provider: "google",
89                message: "response contained no translations".into(),
90            });
91        }
92        let translated = body
93            .data
94            .translations
95            .into_iter()
96            .map(|t| t.translated_text)
97            .collect::<Vec<_>>()
98            .join("\n");
99
100        Ok(translated)
101    }
102}
103
104#[derive(Debug, Serialize)]
105struct GoogleTranslateRequest<'a> {
106    q: Vec<&'a str>,
107    target: &'a str,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    source: Option<&'a str>,
110    format: &'static str,
111}
112
113#[derive(Debug, Deserialize)]
114struct GoogleTranslateResponse {
115    data: GoogleTranslateData,
116}
117
118#[derive(Debug, Deserialize)]
119struct GoogleTranslateData {
120    translations: Vec<GoogleTranslation>,
121}
122
123#[derive(Debug, Deserialize)]
124struct GoogleTranslation {
125    #[serde(rename = "translatedText")]
126    translated_text: String,
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_string_contains, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn translates_numbered_text() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> Olá" }, { "translatedText": "<2> mundo" }]
                }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt",
                source_language: None,
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Olá\n<2> mundo");
    }

    #[tokio::test]
    async fn sends_api_key_as_query_param() {
        let server = MockServer::start().await;

        // The mock only matches when `?key=my-key` is present; `expect(1)` makes
        // the server fail the test on drop if no matching request arrived.
        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(query_param("key", "my-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> ok" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "my-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> test",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn sends_format_text() {
        let server = MockServer::start().await;

        // "format": "text" prevents HTML interpretation of ASS tags; the mock
        // only matches when the serialized body actually carries it.
        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(body_string_contains(r#""format":"text""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> ok" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: r"<1> {\b1}Bold text",
                target_language: "es",
                source_language: None,
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn error_on_non_200() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(403)
                    .set_body_string(r#"{"error": {"message": "Daily Limit Exceeded"}}"#),
            )
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "bad-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt",
                source_language: None,
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("google"));
        assert!(err.to_string().contains("Daily Limit Exceeded"));
    }

    #[tokio::test]
    async fn error_on_empty_translations() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": { "translations": [] }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt",
                source_language: None,
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("no translations"));
    }

    #[tokio::test]
    async fn sends_source_when_provided() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(body_string_contains(r#""source":"ja""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": { "translations": [{ "translatedText": "<1> Olá" }] }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt",
                source_language: Some("ja"),
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn sends_each_line_as_separate_array_element() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(body_string_contains(r#""<1> hello""#))
            .and(body_string_contains(r#""<2> world""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> Olá" }, { "translatedText": "<2> mundo" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt",
                source_language: None,
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn joins_translated_lines_back() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [
                        { "translatedText": "<1> Zeile eins" },
                        { "translatedText": "<2> Zeile zwei" },
                        { "translatedText": "<3> Zeile drei" }
                    ]
                }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> Line one\n<2> Line two\n<3> Line three",
                target_language: "de",
                source_language: None,
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Zeile eins\n<2> Zeile zwei\n<3> Zeile drei");
    }
}