use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
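
/// Provider backed by Google's Generative Language (Gemini) API, exposed
/// through the OpenAI-compatible `Provider` trait. Pricing for known models
/// is kept in an in-memory cache built at construction time.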
#[derive(Debug, Clone)]
pub struct GoogleProvider {
base: BaseProvider,
pricing_cache: HashMap<String, ModelPricing>,
}
impl GoogleProvider {
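    /// Builds the provider from `config`, defaulting the base URL to the
    /// public Generative Language endpoint when none is configured.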
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1".to_string());
let provider = Self {
base: BaseProvider { base_url, ..base },
pricing_cache: Self::initialize_pricing_cache(),
};
info!("Google provider '{}' initialized successfully", config.name);
Ok(provider)
}
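
    /// Builds the static per-1K-token pricing table. Prices are hardcoded
    /// snapshots rather than values fetched from Google.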
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"gemini-pro".to_string(),
ModelPricing {
model: "gemini-pro".to_string(),
input_cost_per_1k: 0.00025,
output_cost_per_1k: 0.0005,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gemini-ultra".to_string(),
ModelPricing {
model: "gemini-ultra".to_string(),
input_cost_per_1k: 0.00125,
output_cost_per_1k: 0.00375,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
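
    /// Converts OpenAI-style chat messages into Gemini's `contents` format:
    /// assistant turns map to the "model" role, and any system message is
    /// folded into a leading user turn prefixed with "System:".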
fn convert_messages_to_google(&self, messages: &[ChatMessage]) -> serde_json::Value {
let mut google_messages = Vec::new();
let mut system_message = None;
for message in messages {
match message.role {
MessageRole::System => {
if let Some(MessageContent::Text(text)) = &message.content {
system_message = Some(text.clone());
}
}
MessageRole::User => {
google_messages.push(json!({
"role": "user",
"parts": [
{
"text": self.extract_text_content(message.content.as_ref())
}
]
}));
}
MessageRole::Assistant => {
google_messages.push(json!({
"role": "model",
"parts": [
{
"text": self.extract_text_content(message.content.as_ref())
}
]
}));
}
                // Tool and function messages have no direct equivalent in
                // this conversion and are skipped.
                _ => {}
}
}
if let Some(system) = system_message {
google_messages.insert(
0,
json!({
"role": "user",
"parts": [
{
"text": format!("System: {}", system)
}
]
}),
);
}
json!({ "contents": google_messages })
}
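
    /// Flattens message content to plain text; non-text parts (e.g. image
    /// parts) are dropped.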
fn extract_text_content(&self, content: Option<&MessageContent>) -> String {
match content {
Some(MessageContent::Text(text)) => text.clone(),
Some(MessageContent::Parts(parts)) => parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join("\n"),
None => String::new(),
}
}
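
    /// Maps a raw Gemini generateContent response onto an OpenAI-style
    /// `ChatCompletionResponse`, taking the first candidate's first text
    /// part as the reply.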
fn convert_google_response_to_openai(
&self,
google_response: serde_json::Value,
model: &str,
) -> Result<ChatCompletionResponse> {
        let content = google_response
            .pointer("/candidates/0/content/parts/0/text")
            .and_then(|text| text.as_str())
            .unwrap_or("")
            .to_string();
        // Gemini reports token counts in `usageMetadata`; fall back to zero
        // when the field is absent.
        let prompt_tokens = google_response
            .pointer("/usageMetadata/promptTokenCount")
            .and_then(|v| v.as_u64())
            .unwrap_or(0) as u32;
        let completion_tokens = google_response
            .pointer("/usageMetadata/candidatesTokenCount")
            .and_then(|v| v.as_u64())
            .unwrap_or(0) as u32;
        let usage = Usage {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };
Ok(ChatCompletionResponse {
id: format!("chatcmpl-google-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.to_string(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: MessageRole::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
                // Gemini's candidate `finishReason` is not inspected; "stop"
                // is reported unconditionally.
                finish_reason: Some("stop".to_string()),
logprobs: None,
}],
usage: Some(usage),
system_fingerprint: None,
})
}
}
#[async_trait]
impl Provider for GoogleProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Google
}
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model) || model.starts_with("gemini-")
}
    async fn supports_images(&self) -> bool {
        true
    }
async fn supports_embeddings(&self) -> bool {
true
}
async fn supports_streaming(&self) -> bool {
true
}
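
    /// Returns a hardcoded list of known Gemini models rather than querying
    /// Google's models endpoint.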
async fn list_models(&self) -> Result<Vec<Model>> {
let known_models = vec!["gemini-pro", "gemini-ultra", "gemini-pro-vision"];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "google".to_string(),
})
.collect();
Ok(models)
}
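
    /// Lightweight health check; currently a no-op that always succeeds
    /// without contacting Google.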
async fn health_check(&self) -> Result<()> {
debug!("Performing Google health check");
Ok(())
}
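
    /// Translates an OpenAI chat completion request into a Gemini
    /// generateContent call and converts the response back.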
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!("Google chat completion for model: {}", request.model);
        // Accept prefixed ids like "google/gemini-pro" by taking the last
        // path segment.
        let model_name = request
            .model
            .split('/')
            .next_back()
            .unwrap_or(&request.model);
let endpoint = format!(
"models/{}:generateContent?key={}",
model_name, self.base.api_key
);
let mut body = self.convert_messages_to_google(&request.messages);
let mut generation_config = json!({});
if let Some(max_tokens) = request.max_tokens {
generation_config["maxOutputTokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
generation_config["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
generation_config["topP"] = json!(top_p);
}
        // `generation_config` was created via json!({}), so as_object() cannot fail.
        if !generation_config.as_object().unwrap().is_empty() {
            body["generationConfig"] = generation_config;
        }
let response = self
.base
.make_request(reqwest::Method::POST, &endpoint, Some(body))
.await?;
let google_response: serde_json::Value = self.base.parse_json_response(response).await?;
self.convert_google_response_to_openai(google_response, &request.model)
}
async fn completion(
&self,
_request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
Err(ProviderError::InvalidRequest(
"Google does not support legacy completion endpoint".to_string(),
)
.into())
}
async fn embedding(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
Err(
ProviderError::InvalidRequest("Google embedding not implemented yet".to_string())
.into(),
)
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
Err(ProviderError::InvalidRequest(
"Google image generation not implemented yet".to_string(),
)
.into())
}
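
    /// Returns cached pricing for `model`, falling back to a generic default
    /// for models not in the cache.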
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
                input_cost_per_1k: 0.0005,
                output_cost_per_1k: 0.0015,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
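
    /// Computes the request cost in USD from per-1K-token pricing.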
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}
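
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal test sketch: it only exercises the hardcoded pricing table,
    // which needs no network access or provider config. The asserted values
    // mirror the constants in `initialize_pricing_cache` above.
    #[test]
    fn pricing_cache_contains_known_gemini_models() {
        let cache = GoogleProvider::initialize_pricing_cache();

        let pro = cache.get("gemini-pro").expect("gemini-pro should be priced");
        assert_eq!(pro.input_cost_per_1k, 0.00025);
        assert_eq!(pro.output_cost_per_1k, 0.0005);

        let ultra = cache
            .get("gemini-ultra")
            .expect("gemini-ultra should be priced");
        assert!(ultra.input_cost_per_1k > pro.input_cost_per_1k);
        assert!(ultra.output_cost_per_1k > pro.output_cost_per_1k);
    }
}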