use serde::{Deserialize, Serialize};
use crate::error::{Result, SubtitleToolkitError};
use super::{TranslationRequest, Translator};
/// [`Translator`] backed by Google's Gemini `generateContent` REST endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is
/// never written to logs or panic messages via `{:?}`.
#[derive(Clone)]
pub struct GeminiTranslator {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// API root without a trailing slash.
    base_url: String,
    /// Secret sent in the `x-goog-api-key` header.
    api_key: String,
    /// Model id interpolated into the request path.
    model: String,
}

impl std::fmt::Debug for GeminiTranslator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiTranslator")
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish_non_exhaustive()
    }
}
impl GeminiTranslator {
    /// Root of the public Gemini API.
    const DEFAULT_BASE_URL: &'static str = "https://generativelanguage.googleapis.com";

    /// Creates a translator that talks to the public Gemini API.
    ///
    /// # Errors
    ///
    /// Returns [`SubtitleToolkitError::Http`] if the HTTP client cannot be built.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        Self::with_base_url(Self::DEFAULT_BASE_URL, api_key, model)
    }

    /// Creates a translator against a custom API root (e.g. a proxy or a mock server).
    ///
    /// Trailing slashes are stripped from `base_url`. `model` may be a bare id
    /// (`gemini-2.0-flash`) or a resource name (`models/gemini-2.0-flash`). A single
    /// leading `models/` is removed because the request path already contains it.
    ///
    /// Requests that take longer than 120 seconds are aborted by the HTTP client.
    ///
    /// # Errors
    ///
    /// Returns [`SubtitleToolkitError::Http`] if the HTTP client cannot be built.
    pub fn with_base_url(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> Result<Self> {
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .build()
            .map_err(SubtitleToolkitError::Http)?;

        let model: String = model.into();
        // Without this, "models/x" would produce ".../models/models/x:generateContent".
        let model = model.strip_prefix("models/").unwrap_or(&model).to_string();

        Ok(Self {
            client,
            base_url: base_url.into().trim_end_matches('/').to_string(),
            api_key: api_key.into(),
            model,
        })
    }
}
#[async_trait::async_trait]
impl Translator for GeminiTranslator {
    /// Translates `request.source_text` into `request.target_language` with one
    /// `generateContent` call.
    ///
    /// The text of every part of the first candidate is concatenated, because a
    /// single answer may be split across several parts. Surrounding whitespace is
    /// trimmed.
    ///
    /// # Errors
    ///
    /// - [`SubtitleToolkitError::Translation`] when the API returns a non-success
    ///   status (the message includes the status and the response body). It is also
    ///   returned when the response has no candidates or the first candidate has no
    ///   parts.
    /// - Transport and JSON-decoding failures are propagated through `?`.
    async fn translate(&self, request: TranslationRequest<'_>) -> Result<String> {
        const SYSTEM_PROMPT: &str = "You are a subtitle translator. \
            Translate the user text to the requested language. \
            Preserve every numeric tag exactly, like <1>, <2>, <3>. \
            Return only translated subtitle lines. \
            Do not add explanations, markdown, notes, or code fences. \
            Keep line breaks inside each subtitle when needed. \
            Do not add curly-brace commands or backslash formatting.";

        let url = format!(
            "{}/v1beta/models/{}:generateContent",
            self.base_url, self.model
        );

        let payload = GenerateContentRequest {
            system_instruction: SystemInstruction {
                parts: vec![Part {
                    text: SYSTEM_PROMPT.to_string(),
                }],
            },
            contents: vec![Content {
                parts: vec![Part {
                    text: format!(
                        "Target language: {lang}\n\n{text}",
                        lang = request.target_language,
                        text = request.source_text,
                    ),
                }],
            }],
            generation_config: GenerationConfig { temperature: 0.2 },
        };

        // `.json()` already sets `Content-Type: application/json`.
        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", &self.api_key)
            .json(&payload)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // Keep the body (it may hold the API's error description), but always
            // report the status so an empty or unreadable body is still diagnosable.
            let body = response.text().await.unwrap_or_default();
            let body = body.trim();
            let message = if body.is_empty() {
                format!("request failed with HTTP {status}")
            } else {
                format!("HTTP {status}: {body}")
            };
            return Err(SubtitleToolkitError::Translation {
                provider: "gemini",
                message,
            });
        }

        let body = response.json::<GenerateContentResponse>().await?;
        let candidate = body
            .candidates
            .into_iter()
            .next()
            .ok_or_else(|| SubtitleToolkitError::Translation {
                provider: "gemini",
                message: "response contained no candidates".into(),
            })?;

        let parts = candidate.content.parts;
        if parts.is_empty() {
            return Err(SubtitleToolkitError::Translation {
                provider: "gemini",
                message: "response candidate contained no parts".into(),
            });
        }

        // Join all parts; taking only the first one would truncate split answers.
        let text: String = parts.into_iter().map(|part| part.text).collect();
        Ok(text.trim().to_string())
    }
}
/// JSON body posted to `models/{model}:generateContent`.
#[derive(Serialize, Debug)]
struct GenerateContentRequest {
    /// Instructions that steer the model for the whole request.
    system_instruction: SystemInstruction,
    /// Conversation turns; the translator sends exactly one user turn.
    contents: Vec<Content>,
    /// Sampling parameters.
    generation_config: GenerationConfig,
}
/// System-level prompt, expressed as a list of text parts.
#[derive(Serialize, Debug)]
struct SystemInstruction {
    /// Prompt fragments.
    parts: Vec<Part>,
}
/// A single conversation turn sent to the model.
#[derive(Serialize, Debug)]
struct Content {
    /// Text fragments making up the turn.
    parts: Vec<Part>,
}
/// Plain-text fragment used in both the system instruction and the user turn.
#[derive(Serialize, Debug)]
struct Part {
    /// Fragment text.
    text: String,
}
/// Sampling settings for the request.
#[derive(Serialize, Debug)]
struct GenerationConfig {
    /// Sampling temperature sent to the API.
    temperature: f32,
}
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
candidates: Vec<Candidate>,
}
#[derive(Debug, Deserialize)]
struct Candidate {
content: CandidateContent,
}
#[derive(Debug, Deserialize)]
struct CandidateContent {
parts: Vec<CandidatePart>,
}
/// A single text fragment of a generated answer.
#[derive(Deserialize, Debug)]
struct CandidatePart {
    /// Fragment text; required, so a part without text fails decoding.
    text: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Model id passed to every translator under test.
    const MODEL: &str = "gemini-2.0-flash";

    /// Minimal successful `generateContent` body with a single text part.
    fn gemini_response(text: &str) -> serde_json::Value {
        serde_json::json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": text }],
                    "role": "model"
                },
                "finishReason": "STOP"
            }]
        })
    }

    /// Starts a fresh mock server with `mock` already mounted.
    async fn serve(mock: Mock) -> MockServer {
        let server = MockServer::start().await;
        mock.mount(&server).await;
        server
    }

    /// Mock that answers every POST with `template`, whatever the path or headers.
    fn any_post(template: ResponseTemplate) -> Mock {
        Mock::given(method("POST")).respond_with(template)
    }

    /// Builds a translator for `server` using `api_key` and runs one translation.
    async fn translate_via(
        server: &MockServer,
        api_key: &str,
        source_text: &str,
        target_language: &str,
    ) -> Result<String> {
        let translator = GeminiTranslator::with_base_url(server.uri(), api_key, MODEL).unwrap();
        translator
            .translate(TranslationRequest {
                source_text,
                target_language,
            })
            .await
    }

    #[tokio::test]
    async fn translates_numbered_text() {
        let server = serve(
            Mock::given(method("POST"))
                .and(path_regex(r"/v1beta/models/.*:generateContent"))
                .respond_with(
                    ResponseTemplate::new(200)
                        .set_body_json(gemini_response("<1> Olá\n<2> mundo")),
                ),
        )
        .await;

        let result = translate_via(&server, "test-key", "<1> hello\n<2> world", "pt-BR")
            .await
            .unwrap();
        assert_eq!(result, "<1> Olá\n<2> mundo");
    }

    #[tokio::test]
    async fn sends_api_key_header() {
        let server = serve(
            Mock::given(method("POST"))
                .and(header("x-goog-api-key", "my-secret-key"))
                .respond_with(ResponseTemplate::new(200).set_body_json(gemini_response("<1> ok"))),
        )
        .await;

        // The mock only matches when the header is present, so success proves it was sent.
        translate_via(&server, "my-secret-key", "<1> test", "de")
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn trims_whitespace_from_response() {
        let server = serve(any_post(
            ResponseTemplate::new(200).set_body_json(gemini_response(" <1> Olá \n")),
        ))
        .await;

        let result = translate_via(&server, "test-key", "<1> hello", "pt-BR")
            .await
            .unwrap();
        assert_eq!(result, "<1> Olá");
    }

    #[tokio::test]
    async fn error_on_non_200() {
        let server = serve(any_post(
            ResponseTemplate::new(429).set_body_string(r#"{"error": {"message": "Quota exceeded"}}"#),
        ))
        .await;

        let message = translate_via(&server, "bad-key", "<1> hello", "pt-BR")
            .await
            .unwrap_err()
            .to_string();
        assert!(message.contains("gemini"));
        assert!(message.contains("Quota exceeded"));
    }

    #[tokio::test]
    async fn error_on_empty_candidates() {
        let server = serve(any_post(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({ "candidates": [] })),
        ))
        .await;

        let message = translate_via(&server, "test-key", "<1> hello", "pt-BR")
            .await
            .unwrap_err()
            .to_string();
        assert!(message.contains("no candidates"));
    }

    #[tokio::test]
    async fn error_on_empty_parts() {
        let server = serve(any_post(ResponseTemplate::new(200).set_body_json(
            serde_json::json!({
                "candidates": [{
                    "content": { "parts": [], "role": "model" },
                    "finishReason": "STOP"
                }]
            }),
        )))
        .await;

        let message = translate_via(&server, "test-key", "<1> hello", "pt-BR")
            .await
            .unwrap_err()
            .to_string();
        assert!(message.contains("no parts"));
    }
}