use crate::load_balancer::tasks::TaskDefinition;
use crate::providers::instances::{LlmInstance, BaseInstance};
use crate::providers::types::{LlmRequest, LlmResponse, LlmStream, StreamChunk, TokenUsage, Message};
use crate::errors::{LlmError, LlmResult};
use crate::constants;
use async_trait::async_trait;
use reqwest::header;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use log::debug;
use futures::StreamExt;
/// LLM provider instance that talks to the Google generateContent API.
///
/// All provider state (API key, model name, supported tasks, enabled flag and
/// the HTTP client) is read through the wrapped `BaseInstance`.
pub struct GoogleInstance {
// Shared provider state; constructed in `new` with the provider name "google".
base: BaseInstance,
}
/// JSON body posted to the `generateContent` and `streamGenerateContent` endpoints.
///
/// Field names are serialized in camelCase (`contents`, `generationConfig`).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerateContentRequest {
    /// Conversation turns, in the order they are sent.
    contents: Vec<GoogleContent>,
    /// Optional generation settings; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
}
/// One conversation turn, used both in outgoing requests and in the
/// candidates parsed from responses.
#[derive(Serialize, Deserialize)]
struct GoogleContent {
// Speaker of the turn; the request mapping in this file emits "user" or "model".
role: String,
// Text fragments of the turn; readers in this file concatenate them in order.
parts: Vec<GooglePart>,
}
/// A single text fragment inside a `GoogleContent`.
///
/// `text` is a required field, so a part without it fails deserialization.
#[derive(Serialize, Deserialize)]
struct GooglePart {
// The fragment's text payload.
text: String,
}
/// Optional generation settings sent under the request's `generationConfig` key.
///
/// Field names are serialized in camelCase and any unset field is left out of
/// the JSON altogether.
#[derive(Default, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    /// Sampling temperature, copied from the incoming request when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Output token limit; serialized as `maxOutputTokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
}
/// Body of a successful non-streaming `generateContent` response.
///
/// `candidates` is required here: a body without that key fails to deserialize.
#[derive(Deserialize)]
struct GoogleGenerateContentResponse {
// Generated candidates; `generate` reads only the first one.
candidates: Vec<GoogleCandidate>,
}
/// One candidate from a non-streaming response.
///
/// Field names are read in camelCase (`content`, `tokenCount`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    /// The generated turn; required, so a candidate without it fails to parse.
    content: GoogleContent,
    /// Value of the `tokenCount` key; falls back to 0 when the key is absent.
    #[serde(default)]
    token_count: u32,
}
/// JSON payload of one `data:` line in the streaming response.
#[derive(Deserialize)]
struct GoogleStreamChunk {
// Candidates carried by this event; optional, and only the first is consumed.
candidates: Option<Vec<GoogleStreamCandidate>>,
}
/// One candidate inside a streaming event.
///
/// Field names are read in camelCase (`content`, `finishReason`); both are optional.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamCandidate {
    /// Text produced in this event, if any.
    content: Option<GoogleContent>,
    /// Present on an event that ends generation; its mere presence marks the
    /// resulting stream chunk as final.
    finish_reason: Option<String>,
}
impl GoogleInstance {
    /// Creates a Google provider instance.
    ///
    /// The provider name passed to `BaseInstance::new` is fixed to `"google"`;
    /// every other argument is forwarded unchanged.
    pub fn new(api_key: String, model: String, supported_tasks: HashMap<String, TaskDefinition>, enabled: bool) -> Self {
        let base = BaseInstance::new("google".to_string(), api_key, model, supported_tasks, enabled);
        Self { base }
    }

    /// Converts generic chat messages into the `contents` array of a Google request.
    ///
    /// Mapping rules:
    /// * `"assistant"` is renamed to `"model"`; `"user"` and `"model"` pass through.
    /// * A single `"system"` message is allowed. It is not sent as its own turn;
    ///   its text is prepended (followed by a blank line) to the first user message.
    /// * Messages with any other role are skipped with a warning.
    ///
    /// # Errors
    ///
    /// Returns `LlmError::ApiError` when
    /// * more than one system message is present,
    /// * a system message is present but there is no user message to carry it,
    /// * no usable message remains after mapping, or
    /// * the first mapped message does not have the `"user"` role.
    fn map_messages_to_contents(messages: &[Message]) -> LlmResult<Vec<GoogleContent>> {
        let mut contents: Vec<GoogleContent> = Vec::new();
        // Borrowed from `messages`; it is only read once, when merged below.
        let mut system_prompt: Option<&str> = None;
        let mut first_user_message_index: Option<usize> = None;

        for msg in messages {
            match msg.role.as_str() {
                "system" => {
                    if system_prompt.is_some() {
                        return Err(LlmError::ApiError("Multiple system messages are not supported by Google provider mapping.".to_string()));
                    }
                    system_prompt = Some(msg.content.as_str());
                }
                "user" | "model" | "assistant" => {
                    let role = if msg.role == "assistant" { "model" } else { msg.role.as_str() };
                    if role == "user" && first_user_message_index.is_none() {
                        // Index the message is about to occupy in `contents`.
                        first_user_message_index = Some(contents.len());
                    }
                    contents.push(GoogleContent {
                        role: role.to_string(),
                        parts: vec![GooglePart { text: msg.content.clone() }],
                    });
                }
                _ => {
                    log::warn!("Ignoring message with unknown role: {}", msg.role);
                }
            }
        }

        if let Some(sys_prompt) = system_prompt {
            let user_idx = first_user_message_index.ok_or_else(|| {
                LlmError::ApiError("System message provided but no user message found.".to_string())
            })?;
            // `user_idx` was recorded immediately before the matching push, so the
            // lookup always succeeds; every pushed entry has exactly one part.
            if let Some(part) = contents
                .get_mut(user_idx)
                .and_then(|content| content.parts.first_mut())
            {
                part.text = format!("{}\n\n{}", sys_prompt, part.text);
            }
        }

        let first_role = match contents.first() {
            Some(first) => first.role.as_str(),
            None => {
                return Err(LlmError::ApiError("No valid messages found for Google provider.".to_string()));
            }
        };
        if first_role != "user" {
            return Err(LlmError::ApiError(format!("Google chat must start with a 'user' role message, found '{}'.", first_role)));
        }

        Ok(contents)
    }
}
#[async_trait]
impl LlmInstance for GoogleInstance {
async fn generate(&self, request: &LlmRequest) -> LlmResult<LlmResponse> {
if !self.base.is_enabled() {
return Err(LlmError::ProviderDisabled("Google".to_string()));
}
let model_name = self.base.model();
let api_key = self.base.api_key();
let url = format!(
"{}/v1beta/models/{}:generateContent?key={}",
constants::GOOGLE_API_ENDPOINT_PREFIX,
model_name,
api_key
);
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
let contents = Self::map_messages_to_contents(&request.messages)?;
let mut generation_config = GoogleGenerationConfig::default();
generation_config.temperature = request.temperature;
generation_config.max_output_tokens = request.max_tokens;
let google_request = GoogleGenerateContentRequest {
contents,
generation_config: Some(generation_config).filter(|gc| {
gc.temperature.is_some() || gc.max_output_tokens.is_some()
}),
};
let response = self.base.client()
.post(&url)
.headers(headers)
.json(&google_request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error_json: Result<serde_json::Value, _> = response.json().await;
let error_details = match error_json {
Ok(json) => json.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| format!("Unknown error structure: {}", json)),
Err(_) => "Failed to parse error response body".to_string(),
};
return Err(LlmError::ApiError(format!(
"Google API error ({}): {}",
status, error_details
)));
}
let google_response: GoogleGenerateContentResponse = response.json().await
.map_err(|e| LlmError::ApiError(format!("Failed to parse Google JSON response: {}", e)))?;
if google_response.candidates.is_empty() {
return Err(LlmError::ApiError("No candidates returned from Google. Content may have been blocked.".to_string()));
}
let candidate = &google_response.candidates[0];
let combined_content = candidate.content.parts.iter()
.map(|part| part.text.clone())
.collect::<Vec<String>>()
.join("");
let usage = if candidate.token_count > 0 {
Some(TokenUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: candidate.token_count,
})
} else {
None
};
debug!("Google usage: {:?}", usage);
Ok(LlmResponse {
content: combined_content,
model: model_name.to_string(),
usage,
})
}
async fn generate_stream(&self, request: &LlmRequest) -> LlmResult<LlmStream> {
if !self.base.is_enabled() {
return Err(LlmError::ProviderDisabled("Google".to_string()));
}
let model_name = self.base.model();
let api_key = self.base.api_key();
let url = format!(
"{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
constants::GOOGLE_API_ENDPOINT_PREFIX,
model_name,
api_key
);
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
let contents = Self::map_messages_to_contents(&request.messages)?;
let mut generation_config = GoogleGenerationConfig::default();
generation_config.temperature = request.temperature;
generation_config.max_output_tokens = request.max_tokens;
let google_request = GoogleGenerateContentRequest {
contents,
generation_config: Some(generation_config).filter(|gc| {
gc.temperature.is_some() || gc.max_output_tokens.is_some()
}),
};
let response = self.base.client()
.post(&url)
.headers(headers)
.json(&google_request)
.send()
.await?;
let response_status = response.status();
if response_status.as_u16() == 429 {
let error_text = response.text().await
.unwrap_or_else(|_| "Rate limit exceeded".to_string());
return Err(LlmError::RateLimit(format!("Google rate limit: {}", error_text)));
}
if !response_status.is_success() {
let error_json: Result<serde_json::Value, _> = response.json().await;
let error_details = match error_json {
Ok(json) => json.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(|s| s.to_string())
.unwrap_or_else(|| format!("Unknown error: {}", json)),
Err(_) => "Failed to parse error response".to_string(),
};
return Err(LlmError::ApiError(format!("Google API error ({}): {}", response_status, error_details)));
}
let byte_stream = response.bytes_stream();
let chunk_stream = byte_stream
.map(|result| result.map_err(|e| LlmError::RequestError(e)))
.flat_map(|result| {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let chunks: Vec<Result<StreamChunk, LlmError>> = text
.lines()
.filter_map(|line| {
let line = line.trim();
if line.starts_with("data: ") {
let data = &line[6..];
match serde_json::from_str::<GoogleStreamChunk>(data) {
Ok(chunk) => {
if let Some(candidates) = chunk.candidates {
if let Some(candidate) = candidates.first() {
let is_final = candidate.finish_reason.is_some();
if let Some(content) = &candidate.content {
let text = content.parts.iter()
.map(|p| p.text.clone())
.collect::<Vec<_>>()
.join("");
return Some(Ok(StreamChunk {
content: text,
model: None,
is_final,
usage: None,
}));
}
}
}
None
}
Err(e) => {
debug!("Failed to parse Google streaming chunk: {}", e);
None
}
}
} else {
None
}
})
.collect();
futures::stream::iter(chunks)
}
Err(e) => futures::stream::iter(vec![Err(e)])
}
});
Ok(Box::pin(chunk_stream))
}
fn supports_streaming(&self) -> bool {
true
}
fn get_name(&self) -> &str {
self.base.name()
}
fn get_model(&self) -> &str {
self.base.model()
}
fn get_supported_tasks(&self) -> &HashMap<String, TaskDefinition> {
&self.base.supported_tasks()
}
fn is_enabled(&self) -> bool {
self.base.is_enabled()
}
}