// psyche_subtitle_toolkit/translation/google.rs

use std::fmt;

use serde::{Deserialize, Serialize};

use crate::error::{Result, SubtitleToolkitError};

use super::{TranslationRequest, Translator};

7/// Translator backend that calls the [Google Cloud Translation v2](https://cloud.google.com/translate/docs/reference/rest/v2/translate) endpoint.
8///
9/// Uses API key authentication via the `key` query parameter.
10/// The first 500,000 characters/month are free.
11///
12/// # Example
13///
14/// ```no_run
15/// # async fn example() -> psyche_subtitle_toolkit::Result<()> {
16/// use psyche_subtitle_toolkit::GoogleTranslator;
17///
18/// let translator = GoogleTranslator::new("your-api-key")?;
19/// // let result = translator.translate(request).await?;
20/// # Ok(())
21/// # }
22/// ```
23#[derive(Debug, Clone)]
24pub struct GoogleTranslator {
25    client: reqwest::Client,
26    base_url: String,
27    api_key: String,
28}
29
30impl GoogleTranslator {
31    /// Create a new translator targeting the default Google Translate API
32    /// (`https://translation.googleapis.com`).
33    pub fn new(api_key: impl Into<String>) -> Result<Self> {
34        Self::with_base_url("https://translation.googleapis.com", api_key)
35    }
36
37    /// Create a new translator with a custom base URL and API key.
38    pub fn with_base_url(
39        base_url: impl Into<String>,
40        api_key: impl Into<String>,
41    ) -> Result<Self> {
42        let client = reqwest::Client::builder()
43            .timeout(std::time::Duration::from_secs(120))
44            .build()
45            .map_err(SubtitleToolkitError::Http)?;
46        Ok(Self {
47            client,
48            base_url: base_url.into().trim_end_matches('/').to_string(),
49            api_key: api_key.into(),
50        })
51    }
52}
53
54#[async_trait::async_trait]
55impl Translator for GoogleTranslator {
56    async fn translate(&self, request: TranslationRequest<'_>) -> Result<String> {
57        // Google v2 uses language codes like "pt", "en", "ja" — not "pt-BR".
58        // We pass the target as-is; Google handles regional codes (e.g. "pt-BR" works).
59        let response = self
60            .client
61            .post(format!(
62                "{}/language/translate/v2",
63                self.base_url
64            ))
65            .query(&[("key", &self.api_key)])
66            .json(&GoogleTranslateRequest {
67                q: request.source_text.lines().collect(),
68                target: request.target_language,
69                format: "text",
70            })
71            .send()
72            .await?;
73
74        if !response.status().is_success() {
75            return Err(SubtitleToolkitError::Translation {
76                provider: "google",
77                message: response
78                    .text()
79                    .await
80                    .unwrap_or_else(|_| "request failed".into()),
81            });
82        }
83
84        let body = response.json::<GoogleTranslateResponse>().await?;
85        if body.data.translations.is_empty() {
86            return Err(SubtitleToolkitError::Translation {
87                provider: "google",
88                message: "response contained no translations".into(),
89            });
90        }
91        let translated = body
92            .data
93            .translations
94            .into_iter()
95            .map(|t| t.translated_text)
96            .collect::<Vec<_>>()
97            .join("\n");
98
99        Ok(translated)
100    }
101}
102
103#[derive(Debug, Serialize)]
104struct GoogleTranslateRequest<'a> {
105    q: Vec<&'a str>,
106    target: &'a str,
107    format: &'static str,
108}
109
110#[derive(Debug, Deserialize)]
111struct GoogleTranslateResponse {
112    data: GoogleTranslateData,
113}
114
115#[derive(Debug, Deserialize)]
116struct GoogleTranslateData {
117    translations: Vec<GoogleTranslation>,
118}
119
120#[derive(Debug, Deserialize)]
121struct GoogleTranslation {
122    #[serde(rename = "translatedText")]
123    translated_text: String,
124}
125
#[cfg(test)]
mod tests {
    //! Request/response tests against a local wiremock server. Requests that
    //! match no mock get a 404 from wiremock, which surfaces as an `Err` from
    //! `translate`, so a matcher-constrained mock plus `unwrap()` asserts on
    //! the request shape.

    use super::*;
    use wiremock::matchers::{body_partial_json, body_string_contains, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn translates_numbered_text() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> Olá" }, { "translatedText": "<2> mundo" }]
                }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt",
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Olá\n<2> mundo");
    }

    #[tokio::test]
    async fn sends_api_key_as_query_param() {
        let server = MockServer::start().await;

        // Only a request carrying `?key=my-key` matches this mock.
        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(query_param("key", "my-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> ok" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "my-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> test",
                target_language: "de",
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn sends_format_text() {
        let server = MockServer::start().await;

        // "format": "text" prevents HTML interpretation of ASS tags; the mock
        // only matches when the body actually carries it.
        Mock::given(method("POST"))
            .and(body_partial_json(serde_json::json!({ "format": "text", "target": "es" })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> ok" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: r"<1> {\b1}Bold text",
                target_language: "es",
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn error_on_non_200() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(403).set_body_string(r#"{"error": {"message": "Daily Limit Exceeded"}}"#),
            )
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "bad-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt",
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("google"));
        assert!(err.to_string().contains("Daily Limit Exceeded"));
    }

    #[tokio::test]
    async fn error_on_empty_translations() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": { "translations": [] }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let err = translator
            .translate(TranslationRequest {
                source_text: "<1> hello",
                target_language: "pt",
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("no translations"));
    }

    #[tokio::test]
    async fn sends_each_line_as_separate_array_element() {
        let server = MockServer::start().await;

        // The closing quote in each pattern only appears when the line is its
        // own JSON string rather than part of one "\n"-joined string.
        Mock::given(method("POST"))
            .and(path("/language/translate/v2"))
            .and(body_string_contains(r#""<1> hello""#))
            .and(body_string_contains(r#""<2> world""#))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [{ "translatedText": "<1> Olá" }, { "translatedText": "<2> mundo" }]
                }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        translator
            .translate(TranslationRequest {
                source_text: "<1> hello\n<2> world",
                target_language: "pt",
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn joins_translated_lines_back() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "translations": [
                        { "translatedText": "<1> Zeile eins" },
                        { "translatedText": "<2> Zeile zwei" },
                        { "translatedText": "<3> Zeile drei" }
                    ]
                }
            })))
            .mount(&server)
            .await;

        let translator = GoogleTranslator::with_base_url(server.uri(), "test-key").unwrap();
        let result = translator
            .translate(TranslationRequest {
                source_text: "<1> Line one\n<2> Line two\n<3> Line three",
                target_language: "de",
            })
            .await
            .unwrap();

        assert_eq!(result, "<1> Zeile eins\n<2> Zeile zwei\n<3> Zeile drei");
    }
}