use std::collections::BTreeMap;
use std::error::Error as _;
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio::time::sleep;
use crate::domain::{ProviderConfig, ThinkingMode};
use crate::error::TranslatorError;
use crate::pipeline::prompt;
use super::{
ProviderCapabilities, TranslatedItem, TranslationBatch, TranslationBatchOutput,
TranslationProvider,
};
/// Translation provider speaking the OpenAI `chat/completions` wire protocol.
///
/// Also covers endpoints that expose an OpenAI-compatible surface, e.g.
/// Gemini's `generativelanguage.googleapis.com/.../openai` endpoint and
/// NVIDIA-hosted models (see `is_gemini` / `is_nvidia_stepfun`).
pub struct OpenAiCompatibleProvider {
    // HTTP client built with the configured request timeout.
    client: reqwest::Client,
    // Endpoint, model, auth, retry, and thinking-mode settings.
    config: ProviderConfig,
}
impl OpenAiCompatibleProvider {
pub fn new(config: ProviderConfig) -> Result<Self, TranslatorError> {
config.validate()?;
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(config.timeout_seconds))
.build()?;
Ok(Self { client, config })
}
pub fn config(&self) -> &ProviderConfig {
&self.config
}
fn endpoint(&self) -> String {
format!(
"{}/chat/completions",
self.config.base_url.trim_end_matches('/')
)
}
fn build_headers(&self) -> Result<HeaderMap, TranslatorError> {
let mut headers = HeaderMap::new();
for (name, value) in &self.config.custom_headers {
let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|error| {
TranslatorError::InvalidConfig(format!("invalid header name {name:?}: {error}"))
})?;
let header_value = HeaderValue::from_str(value).map_err(|error| {
TranslatorError::InvalidConfig(format!(
"invalid header value for {name:?}: {error}"
))
})?;
headers.insert(header_name, header_value);
}
Ok(headers)
}
}
#[async_trait]
impl TranslationProvider for OpenAiCompatibleProvider {
    /// Provider name for diagnostics; Gemini's OpenAI-compat endpoint is
    /// reported separately, detected from the base URL (`is_gemini`).
    fn name(&self) -> &str {
        if self.is_gemini() {
            "gemini"
        } else {
            "openai-compatible"
        }
    }

    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            supports_json_output: true,
            max_batch_items: None,
            supports_reasoning_fallback: true,
            // True only for endpoints with a known thinking-control request
            // field (currently NVIDIA-hosted StepFun models).
            supports_thinking_control: self.supports_known_thinking_control(),
        }
    }

    /// Translates one batch via a single `chat/completions` request.
    ///
    /// Retry policy: HTTP 429 and 5xx responses, plus transport errors, are
    /// retried up to `config.max_retries` times with linear backoff
    /// (`retry_delay`). Any other non-success status fails immediately.
    ///
    /// Response parsing: the completion text is first parsed as numbered
    /// lines (`prompt::parse_numbered_response`); if that fails, a JSON
    /// `{"translations": [...]}` payload is tried as a fallback.
    async fn translate_batch(
        &self,
        batch: TranslationBatch,
    ) -> Result<TranslationBatchOutput, TranslatorError> {
        let headers = self.build_headers()?;
        let endpoint = self.endpoint();
        let (system_prompt, user_prompt) =
            prompt::build_messages(&batch, self.config.thinking_mode);
        // The body is identical across retries, so build it once up front.
        let request_body = ChatCompletionRequest {
            model: self.config.model.clone(),
            temperature: self.config.temperature,
            response_format: None,
            messages: vec![
                RequestMessage {
                    role: "system",
                    content: system_prompt,
                },
                RequestMessage {
                    role: "user",
                    content: user_prompt,
                },
            ],
            extra_fields: self.build_request_extra_fields(),
        };
        let mut attempt = 0;
        loop {
            let mut request = self
                .client
                .post(&endpoint)
                .headers(headers.clone())
                .json(&request_body);
            // Attach bearer auth only when a non-blank key is configured;
            // some endpoints take no key at all.
            if let Some(api_key) = self.config.api_key.as_deref() {
                if !api_key.trim().is_empty() {
                    request = request.bearer_auth(api_key);
                }
            }
            match request.send().await {
                Ok(response) if response.status().is_success() => {
                    let payload: ChatCompletionResponse = response.json().await?;
                    // Only the first choice is consumed; an empty choice
                    // list is treated as a protocol violation.
                    let content = payload
                        .choices
                        .into_iter()
                        .next()
                        .ok_or_else(|| {
                            TranslatorError::ProviderProtocol(
                                "provider returned no completion choices".to_owned(),
                            )
                        })?
                        .message
                        .into_text()?;
                    let items = match prompt::parse_numbered_response(&batch, &content) {
                        Ok(items) => items,
                        Err(numbered_error) => {
                            // Fallback: some models reply with a JSON object
                            // instead of numbered lines.
                            let structured = parse_structured_response(&content)
                                .and_then(|payload| map_structured_response(&batch, payload));
                            match structured {
                                Ok(items) => items,
                                Err(json_error) => {
                                    // Surface both failures so callers can see
                                    // why each parse strategy rejected the text.
                                    return Err(TranslatorError::ProviderProtocol(format!(
                                        "failed to parse provider response as numbered lines ({numbered_error}) or JSON ({json_error})"
                                    )));
                                }
                            }
                        }
                    };
                    return Ok(TranslationBatchOutput { items });
                }
                Ok(response) => {
                    let status = response.status();
                    let body = response.text().await.unwrap_or_default();
                    // Retry rate-limiting (429) and server-side (5xx) errors.
                    if attempt < self.config.max_retries
                        && (status.as_u16() == 429 || status.is_server_error())
                    {
                        attempt += 1;
                        sleep(retry_delay(attempt)).await;
                        continue;
                    }
                    return Err(TranslatorError::ProviderProtocol(format!(
                        "provider returned {status}: {body}"
                    )));
                }
                // Transport-level failures are retried with the same backoff.
                Err(_error) if attempt < self.config.max_retries => {
                    attempt += 1;
                    sleep(retry_delay(attempt)).await;
                    continue;
                }
                Err(error) => {
                    return Err(TranslatorError::ProviderTransport(format!(
                        "request to {endpoint} failed after {} attempt(s): {}",
                        attempt + 1,
                        describe_reqwest_error(&error)
                    )));
                }
            }
        }
    }
}
impl OpenAiCompatibleProvider {
    /// Whether this endpoint/model pair has a known request-level switch for
    /// enabling or disabling model "thinking".
    fn supports_known_thinking_control(&self) -> bool {
        self.is_nvidia_stepfun()
    }

    /// NVIDIA-hosted StepFun models are the one known combination with a
    /// `chat_template_kwargs.thinking` switch.
    fn is_nvidia_stepfun(&self) -> bool {
        let config = &self.config;
        config.base_url.contains("integrate.api.nvidia.com")
            && config.model.starts_with("stepfun-ai/")
    }

    /// Detects Gemini's OpenAI-compatibility endpoint from the base URL.
    pub fn is_gemini(&self) -> bool {
        self.config
            .base_url
            .contains("generativelanguage.googleapis.com")
    }

    /// Extra top-level request fields for providers that need them.
    /// Currently only the thinking on/off switch for NVIDIA StepFun;
    /// `ThinkingMode::Auto` leaves the endpoint's default untouched.
    fn build_request_extra_fields(&self) -> BTreeMap<String, Value> {
        let mut fields = BTreeMap::new();
        if self.is_nvidia_stepfun() {
            let thinking = match self.config.thinking_mode {
                ThinkingMode::On => Some(true),
                ThinkingMode::Off => Some(false),
                ThinkingMode::Auto => None,
            };
            if let Some(enabled) = thinking {
                fields.insert(
                    "chat_template_kwargs".to_owned(),
                    json!({ "thinking": enabled }),
                );
            }
        }
        fields
    }
}
/// Linear backoff: 250 ms per completed attempt (1 → 250 ms, 2 → 500 ms, …).
fn retry_delay(attempt: u32) -> Duration {
    // `Duration * u32` yields exactly 250 ms × attempt for every u32 input,
    // matching `Duration::from_millis(250 * u64::from(attempt))`.
    Duration::from_millis(250) * attempt
}
/// Renders a reqwest error plus its whole `source()` chain as one
/// colon-separated string, skipping empty and duplicate messages.
fn describe_reqwest_error(error: &reqwest::Error) -> String {
    let mut messages = vec![error.to_string()];
    let mut cause: Option<&(dyn std::error::Error + 'static)> = error.source();
    while let Some(inner) = cause {
        let text = inner.to_string();
        // Avoid repeating identical layers (wrappers often echo their cause).
        if !text.is_empty() && !messages.contains(&text) {
            messages.push(text);
        }
        cause = inner.source();
    }
    messages.join(": ")
}
/// Parses the completion text as a `StructuredTranslationPayload`.
///
/// Tries the whole text first; if that fails (e.g. the model wrapped the
/// JSON in prose or a ```json fence), falls back to the span from the first
/// `{` to the last `}`.
///
/// # Errors
/// Returns `TranslatorError::ProviderProtocol` when no JSON object can be
/// located, and a JSON error (via `TranslatorError::from`) when the located
/// span does not deserialize.
fn parse_structured_response(
    content: &str,
) -> Result<StructuredTranslationPayload, TranslatorError> {
    fn missing_json() -> TranslatorError {
        TranslatorError::ProviderProtocol("provider response did not contain JSON".to_owned())
    }
    serde_json::from_str(content).or_else(|_| {
        let trimmed = content.trim();
        let start = trimmed.find('{').ok_or_else(missing_json)?;
        let end = trimmed.rfind('}').ok_or_else(missing_json)?;
        // Guard: a stray '}' before the first '{' (e.g. "} text {") would
        // otherwise make `start..=end` an invalid range and panic on slicing.
        if end < start {
            return Err(missing_json());
        }
        serde_json::from_str::<StructuredTranslationPayload>(&trimmed[start..=end])
            .map_err(TranslatorError::from)
    })
}
/// Maps a parsed JSON payload onto `TranslatedItem`s, requiring exactly one
/// translation per requested cue.
fn map_structured_response(
    batch: &TranslationBatch,
    payload: StructuredTranslationPayload,
) -> Result<Vec<TranslatedItem>, TranslatorError> {
    let expected = batch.items.len();
    let received = payload.translations.len();
    if received != expected {
        return Err(TranslatorError::ProviderProtocol(format!(
            "provider returned {received} JSON translations for {expected} requested cues"
        )));
    }
    let items = payload
        .translations
        .into_iter()
        .map(|entry| TranslatedItem {
            id: entry.id,
            text: entry.text,
        })
        .collect();
    Ok(items)
}
/// Request body for the OpenAI `chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    temperature: f32,
    messages: Vec<RequestMessage>,
    // Omitted from the serialized JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    // Provider-specific top-level fields (e.g. `chat_template_kwargs`),
    // merged into the request object via `#[serde(flatten)]`.
    #[serde(flatten)]
    extra_fields: BTreeMap<String, Value>,
}
/// One chat message in the request; only static roles ("system", "user")
/// are used by `translate_batch`.
#[derive(Debug, Serialize)]
struct RequestMessage {
    role: &'static str,
    content: String,
}
/// OpenAI `response_format` object (serialized as `{"type": ...}`).
/// NOTE(review): currently never constructed — `translate_batch` always
/// sends `response_format: None`; kept for future structured-output use.
#[derive(Debug, Serialize)]
struct ResponseFormat {
    #[serde(rename = "type")]
    response_type: &'static str,
}
/// Minimal view of a `chat/completions` response: only the choices list is
/// deserialized; all other fields are ignored.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<ChatChoice>,
}
/// One completion choice; only its message is consumed.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    message: ChatMessage,
}
/// Assistant message payload. Some reasoning models return their answer in
/// `reasoning`/`reasoning_content` while leaving `content` null;
/// `into_text` falls back through these fields in priority order.
#[derive(Debug, Deserialize)]
struct ChatMessage {
    content: Option<ChatMessageContent>,
    reasoning: Option<String>,
    reasoning_content: Option<String>,
}
impl ChatMessage {
fn into_text(self) -> Result<String, TranslatorError> {
if let Some(content) = self.content {
let text = content.into_text();
if !text.trim().is_empty() {
return Ok(text);
}
}
if let Some(reasoning_content) = self.reasoning_content {
if !reasoning_content.trim().is_empty() {
return Ok(reasoning_content);
}
}
if let Some(reasoning) = self.reasoning {
if !reasoning.trim().is_empty() {
return Ok(reasoning);
}
}
Err(TranslatorError::ProviderProtocol(
"provider returned a completion message without usable text".to_owned(),
))
}
}
/// `content` arrives either as a plain string or as an array of typed
/// parts; the untagged enum accepts both wire shapes.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ChatMessageContent {
    Text(String),
    Parts(Vec<ChatMessagePart>),
}
impl ChatMessageContent {
    /// Flattens the content into a single string. Multi-part content keeps
    /// only parts that carry text, joined with newlines.
    fn into_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Parts(parts) => {
                let fragments: Vec<String> =
                    parts.into_iter().filter_map(|part| part.text).collect();
                fragments.join("\n")
            }
        }
    }
}
/// One element of a multi-part `content` array; parts without a `text`
/// field deserialize as `None` and are skipped by `into_text`.
#[derive(Debug, Deserialize)]
struct ChatMessagePart {
    text: Option<String>,
}
/// Shape of the JSON fallback payload: `{"translations": [...]}`.
#[derive(Debug, Deserialize)]
struct StructuredTranslationPayload {
    translations: Vec<StructuredTranslationItem>,
}
/// One translated cue in the JSON fallback payload.
#[derive(Debug, Deserialize)]
struct StructuredTranslationItem {
    id: String,
    text: String,
}
#[cfg(test)]
mod tests {
    use serde_json::json;
    use crate::domain::{ProviderConfig, ThinkingMode};
    use super::{ChatCompletionResponse, OpenAiCompatibleProvider, parse_structured_response};

    // The JSON fallback must tolerate a Markdown ```json fence around the payload.
    #[test]
    fn extracts_json_from_wrapped_response() {
        let payload = parse_structured_response(
            "```json\n{\"translations\":[{\"id\":\"cue-1\",\"text\":\"hola\"}]}\n```",
        )
        .expect("should parse embedded JSON");
        assert_eq!(payload.translations.len(), 1);
        assert_eq!(payload.translations[0].text, "hola");
    }

    // `ChatMessage::into_text` must fall back to the reasoning fields when
    // `content` is null, and the recovered text must still parse as JSON.
    #[test]
    fn uses_reasoning_content_when_message_content_is_null() {
        let response: ChatCompletionResponse = serde_json::from_str(
            r#"{
"choices": [
{
"message": {
"content": null,
"reasoning": "{\n \"translations\": [{\n \"id\": \"cue-1\",\n \"text\": \"Olá\"\n }]\n}",
"reasoning_content": "{\n \"translations\": [{\n \"id\": \"cue-1\",\n \"text\": \"Olá\"\n }]\n}"
}
}
]
}"#,
        )
        .expect("response should deserialize");
        let content = response
            .choices
            .into_iter()
            .next()
            .expect("choice should exist")
            .message
            .into_text()
            .expect("reasoning content should be usable text");
        let payload = parse_structured_response(&content).expect("JSON payload should parse");
        assert_eq!(payload.translations.len(), 1);
        assert_eq!(payload.translations[0].text, "Olá");
    }

    // NVIDIA StepFun + ThinkingMode::Off must emit the chat_template_kwargs switch.
    #[test]
    fn adds_known_thinking_control_for_nvidia_stepfun() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig {
            base_url: "https://integrate.api.nvidia.com/v1".to_owned(),
            model: "stepfun-ai/step-3.5-flash".to_owned(),
            thinking_mode: ThinkingMode::Off,
            ..ProviderConfig::default()
        })
        .expect("provider should build");
        let extra_fields = provider.build_request_extra_fields();
        assert_eq!(
            extra_fields.get("chat_template_kwargs"),
            Some(&json!({ "thinking": false }))
        );
    }

    // Unknown endpoints must get no provider-specific extra fields.
    #[test]
    fn omits_thinking_control_for_unknown_provider() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig::default())
            .expect("provider should build");
        assert!(provider.build_request_extra_fields().is_empty());
    }

    #[test]
    fn detects_gemini_provider_by_base_url() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig {
            base_url: "https://generativelanguage.googleapis.com/v1beta/openai".to_owned(),
            model: "gemini-2.0-flash".to_owned(),
            ..ProviderConfig::default()
        })
        .expect("provider should build");
        assert!(provider.is_gemini(), "should detect Gemini by base_url");
    }

    #[test]
    fn gemini_provider_returns_gemini_name() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig {
            base_url: "https://generativelanguage.googleapis.com/v1beta/openai".to_owned(),
            model: "gemini-2.0-flash".to_owned(),
            ..ProviderConfig::default()
        })
        .expect("provider should build");
        use super::super::TranslationProvider;
        assert_eq!(provider.name(), "gemini");
    }

    #[test]
    fn non_gemini_provider_returns_openai_compatible_name() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig::default())
            .expect("provider should build");
        use super::super::TranslationProvider;
        assert_eq!(provider.name(), "openai-compatible");
    }

    // Gemini's OpenAI-compat endpoint must receive no extra request fields.
    #[test]
    fn gemini_provider_emits_no_extra_fields() {
        let provider = OpenAiCompatibleProvider::new(ProviderConfig {
            base_url: "https://generativelanguage.googleapis.com/v1beta/openai".to_owned(),
            model: "gemini-2.0-flash".to_owned(),
            ..ProviderConfig::default()
        })
        .expect("provider should build");
        assert!(
            provider.build_request_extra_fields().is_empty(),
            "Gemini OpenAI-compat endpoint does not require extra request fields"
        );
    }

    // End-to-end against a wiremock server: verifies the request path, the
    // bearer header, the prompt content, and that a numbered-line response
    // is mapped back onto the batch items in order.
    #[tokio::test]
    async fn gemini_provider_translates_batch_via_openai_compat_endpoint() {
        use wiremock::matchers::{body_string_contains, header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        use super::super::{TranslationBatch, TranslationBatchItem, TranslationProvider};
        let server = MockServer::start().await;
        let response = serde_json::json!({
            "choices": [
                {
                    "message": {
                        "content": "1: Olá\n2: Mundo"
                    }
                }
            ]
        });
        Mock::given(method("POST"))
            .and(path("/v1beta/openai/chat/completions"))
            .and(header("authorization", "Bearer test-gemini-key"))
            .and(body_string_contains("1: Hello"))
            .respond_with(ResponseTemplate::new(200).set_body_json(response))
            .expect(1)
            .mount(&server)
            .await;
        let provider = OpenAiCompatibleProvider::new(ProviderConfig {
            base_url: format!("{}/v1beta/openai", server.uri()),
            model: "gemini-2.0-flash".to_owned(),
            api_key: Some("test-gemini-key".to_owned()),
            ..ProviderConfig::default()
        })
        .expect("provider should build");
        let batch = TranslationBatch {
            source_language: None,
            target_language: "Portuguese".to_owned(),
            system_prompt: None,
            items: vec![
                TranslationBatchItem {
                    id: "cue-1".to_owned(),
                    text: "Hello".to_owned(),
                },
                TranslationBatchItem {
                    id: "cue-2".to_owned(),
                    text: "World".to_owned(),
                },
            ],
        };
        let output = provider
            .translate_batch(batch)
            .await
            .expect("translation should succeed");
        assert_eq!(output.items.len(), 2);
        assert_eq!(output.items[0].text, "Olá");
        assert_eq!(output.items[1].text, "Mundo");
        server.verify().await;
    }
}