use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};
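/// OpenAI-compatible provider backed by Google Vertex AI.
///
/// Chat, completion, and embedding requests are translated to the Vertex
/// `generateContent` / `predict` REST endpoints, and responses are mapped
/// back into OpenAI-style types. Authentication uses either a service
/// account key (located via `GOOGLE_APPLICATION_CREDENTIALS`) or, when
/// running on GCP, the instance metadata server.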
#[derive(Debug, Clone)]
pub struct GoogleVertexProvider {
base: BaseProvider,
project_id: String,
region: String,
service_account_key: Option<String>,
access_token: Option<String>,
pricing_cache: HashMap<String, ModelPricing>,
}
impl GoogleVertexProvider {
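    /// Builds a provider from the shared `ProviderConfig`.
    ///
    /// The GCP region is read from the generic `organization` field and
    /// defaults to `us-central1`; the endpoint defaults to the region's
    /// `aiplatform.googleapis.com` host unless `base_url` overrides it.
    ///
    /// Illustrative only; this sketch assumes `ProviderConfig` derives
    /// `Default`:
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     name: "vertex".to_string(),
    ///     project: Some("my-gcp-project".to_string()),
    ///     organization: Some("europe-west4".to_string()), // region
    ///     ..Default::default()
    /// };
    /// let provider = GoogleVertexProvider::new(&config).await?;
    /// ```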
pub async fn new(config: &ProviderConfig) -> Result<Self> {
let base = BaseProvider::new(config)?;
        let project_id = config.project.clone().unwrap_or_default();
        // The generic `organization` config slot doubles as the GCP region here.
        let region = config
            .organization
            .clone()
            .unwrap_or_else(|| "us-central1".to_string());
        // GOOGLE_APPLICATION_CREDENTIALS conventionally holds a *path* to a
        // service account key file, so load the JSON contents for later parsing.
        let service_account_key = std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
            .ok()
            .and_then(|path| std::fs::read_to_string(path).ok());
let base_url = config
.base_url
.clone()
.unwrap_or_else(|| format!("https://{}-aiplatform.googleapis.com", region));
let provider = Self {
base: BaseProvider { base_url, ..base },
project_id,
region,
service_account_key,
access_token: None,
pricing_cache: Self::initialize_pricing_cache(),
};
info!(
"Google Vertex AI provider '{}' initialized successfully",
config.name
);
Ok(provider)
}
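    /// Seeds a static price table for the known Gemini and embedding models.
    ///
    /// Prices are hardcoded snapshots (USD per 1K tokens) and may drift from
    /// Google's published rates; unknown models fall back to the default in
    /// `get_model_pricing`.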
fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
let mut cache = HashMap::new();
cache.insert(
"gemini-pro".to_string(),
ModelPricing {
model: "gemini-pro".to_string(),
input_cost_per_1k: 0.0005,
output_cost_per_1k: 0.0015,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gemini-pro-vision".to_string(),
ModelPricing {
model: "gemini-pro-vision".to_string(),
input_cost_per_1k: 0.0025,
output_cost_per_1k: 0.0075,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gemini-1.5-pro".to_string(),
ModelPricing {
model: "gemini-1.5-pro".to_string(),
input_cost_per_1k: 0.0035,
output_cost_per_1k: 0.0105,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"gemini-1.5-flash".to_string(),
ModelPricing {
model: "gemini-1.5-flash".to_string(),
input_cost_per_1k: 0.000075,
output_cost_per_1k: 0.0003,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache.insert(
"textembedding-gecko".to_string(),
ModelPricing {
model: "textembedding-gecko".to_string(),
input_cost_per_1k: 0.0001,
output_cost_per_1k: 0.0,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
},
);
cache
}
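    /// Returns a bearer token, preferring the service account flow.
    ///
    /// The token is cached on this instance but without expiry tracking, and
    /// callers currently work on short-lived clones, so in practice each
    /// request re-authenticates.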
async fn get_access_token(&mut self) -> Result<String> {
if let Some(token) = &self.access_token {
return Ok(token.clone());
}
if let Some(service_account_key) = &self.service_account_key {
let token = self
.authenticate_with_service_account(service_account_key)
.await?;
self.access_token = Some(token.clone());
Ok(token)
} else {
self.get_default_credentials().await
}
}
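    /// Exchanges a service account key for an access token via the OAuth 2.0
    /// JWT-bearer flow (RFC 7523): a claims set naming the key's
    /// `client_email` and the cloud-platform scope is signed with the RSA
    /// private key, then POSTed to the key's `token_uri`.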
async fn authenticate_with_service_account(&self, key_json: &str) -> Result<String> {
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)]
struct ServiceAccountKey {
client_email: String,
private_key: String,
token_uri: String,
}
#[derive(Debug, Serialize)]
struct Claims {
iss: String,
scope: String,
aud: String,
exp: usize,
iat: usize,
}
let key: ServiceAccountKey = serde_json::from_str(key_json).map_err(|e| {
ProviderError::Authentication(format!("Invalid service account key: {}", e))
})?;
let now = chrono::Utc::now().timestamp() as usize;
let claims = Claims {
iss: key.client_email,
scope: "https://www.googleapis.com/auth/cloud-platform".to_string(),
aud: key.token_uri.clone(),
            exp: now + 3600,
            iat: now,
};
let header = Header::new(Algorithm::RS256);
let encoding_key = EncodingKey::from_rsa_pem(key.private_key.as_bytes())
.map_err(|e| ProviderError::Authentication(format!("Invalid private key: {}", e)))?;
let jwt = encode(&header, &claims, &encoding_key)
.map_err(|e| ProviderError::Authentication(format!("Failed to create JWT: {}", e)))?;
let params = [
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", &jwt),
];
let response = self
.base
.client
.post(&key.token_uri)
            .form(&params)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
let token_response: serde_json::Value = response
.json()
.await
.map_err(|e| ProviderError::Parsing(e.to_string()))?;
token_response
.get("access_token")
.and_then(|t| t.as_str())
.map(|t| t.to_string())
.ok_or_else(|| {
ProviderError::Authentication("Failed to get access token".to_string()).into()
})
}
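    /// Fetches a token from the GCE metadata server, which only resolves when
    /// the process runs on Google Cloud infrastructure.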
async fn get_default_credentials(&self) -> Result<String> {
let metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
let response = self
.base
.client
.get(metadata_url)
.header("Metadata-Flavor", "Google")
.send()
.await;
if let Ok(resp) = response {
if resp.status().is_success() {
let token_response: serde_json::Value = resp
.json()
.await
.map_err(|e| ProviderError::Parsing(e.to_string()))?;
if let Some(token) = token_response.get("access_token").and_then(|t| t.as_str()) {
return Ok(token.to_string());
}
}
}
Err(ProviderError::Authentication(
"No valid credentials found. Please provide service account key or run on GCP"
.to_string(),
)
.into())
}
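    /// Translates OpenAI-style messages into a Vertex `contents` array,
    /// mapping `assistant` to Vertex's `model` role and folding system
    /// messages into user turns.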
fn convert_messages_to_vertex(&self, messages: &[ChatMessage]) -> serde_json::Value {
let mut contents = Vec::new();
for message in messages {
            let role = match message.role {
                MessageRole::User => "user",
                MessageRole::Assistant => "model",
                // Vertex `contents` only accepts "user"/"model", so system
                // (and any other) messages are folded into the user turn.
                MessageRole::System => "user",
                _ => "user",
            };
let parts = match &message.content {
Some(MessageContent::Text(text)) => {
vec![json!({"text": text})]
}
Some(MessageContent::Parts(parts)) => {
parts.iter().map(|part| match part {
ContentPart::Text { text } => json!({"text": text}),
                    ContentPart::ImageUrl { image_url } => {
                        // Data URLs look like "data:<mime>;base64,<payload>";
                        // extract both pieces instead of assuming JPEG, and
                        // fall back to passing the raw value through unchanged.
                        let (mime_type, data) = image_url
                            .url
                            .strip_prefix("data:")
                            .and_then(|rest| rest.split_once(";base64,"))
                            .unwrap_or(("image/jpeg", image_url.url.as_str()));
                        json!({
                            "inline_data": {
                                "mime_type": mime_type,
                                "data": data
                            }
                        })
                    }
ContentPart::Audio { .. } => {
json!({"text": "[Audio content not supported]"})
}
}).collect()
}
None => {
vec![json!({"text": ""})]
}
};
contents.push(json!({
"role": role,
"parts": parts
}));
}
json!({
"contents": contents
})
}
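    /// Maps a Vertex `generateContent` response (first candidate text,
    /// `finishReason`, and `usageMetadata` token counts) onto the OpenAI
    /// chat completion shape.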
fn convert_vertex_response_to_openai(
&self,
vertex_response: serde_json::Value,
model: &str,
) -> Result<ChatCompletionResponse> {
let candidates = vertex_response
.get("candidates")
.and_then(|c| c.as_array())
.ok_or_else(|| ProviderError::Parsing("No candidates in response".to_string()))?;
let first_candidate = candidates
.first()
.ok_or_else(|| ProviderError::Parsing("No candidate in response".to_string()))?;
let content = first_candidate
.get("content")
.and_then(|c| c.get("parts"))
.and_then(|p| p.as_array())
.and_then(|arr| arr.first())
.and_then(|part| part.get("text"))
.and_then(|text| text.as_str())
.unwrap_or("")
.to_string();
let finish_reason = first_candidate
.get("finishReason")
.and_then(|r| r.as_str())
.map(|r| match r {
"STOP" => "stop",
"MAX_TOKENS" => "length",
"SAFETY" => "content_filter",
_ => "stop",
})
.unwrap_or("stop")
.to_string();
let usage_metadata = vertex_response.get("usageMetadata");
let usage = if let Some(metadata) = usage_metadata {
Usage {
prompt_tokens: metadata
.get("promptTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
completion_tokens: metadata
.get("candidatesTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
total_tokens: metadata
.get("totalTokenCount")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u32,
prompt_tokens_details: None,
completion_tokens_details: None,
}
} else {
Usage::default()
};
Ok(ChatCompletionResponse {
id: format!("chatcmpl-vertex-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: model.to_string(),
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: MessageRole::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
finish_reason: Some(finish_reason),
logprobs: None,
}],
usage: Some(usage),
system_fingerprint: None,
})
}
}
#[async_trait]
impl Provider for GoogleVertexProvider {
fn name(&self) -> &str {
&self.base.name
}
fn provider_type(&self) -> ProviderType {
ProviderType::Custom("google_vertex".to_string())
}
async fn supports_model(&self, model: &str) -> bool {
self.base.is_model_supported(model)
|| model.starts_with("gemini")
|| model.starts_with("text-bison")
|| model.starts_with("chat-bison")
|| model.starts_with("textembedding")
}
    async fn supports_images(&self) -> bool {
        true
    }
    async fn supports_embeddings(&self) -> bool {
        true
    }
    async fn supports_streaming(&self) -> bool {
        true
    }
async fn list_models(&self) -> Result<Vec<Model>> {
let known_models = vec![
"gemini-pro",
"gemini-pro-vision",
"gemini-1.5-pro",
"gemini-1.5-flash",
"text-bison",
"text-bison-32k",
"chat-bison",
"chat-bison-32k",
"textembedding-gecko",
"textembedding-gecko-multilingual",
];
let models = known_models
.into_iter()
.map(|model| Model {
id: model.to_string(),
object: "model".to_string(),
created: chrono::Utc::now().timestamp() as u64,
owned_by: "google".to_string(),
})
.collect();
Ok(models)
}
async fn health_check(&self) -> Result<()> {
debug!("Performing Google Vertex AI health check");
let mut provider = self.clone();
provider.get_access_token().await?;
Ok(())
}
async fn chat_completion(
&self,
request: ChatCompletionRequest,
_context: RequestContext,
) -> Result<ChatCompletionResponse> {
debug!(
"Google Vertex AI chat completion for model: {}",
request.model
);
let mut provider = self.clone();
let access_token = provider.get_access_token().await?;
        let mut final_body = self.convert_messages_to_vertex(&request.messages);
let mut generation_config = json!({});
if let Some(max_tokens) = request.max_tokens {
generation_config["maxOutputTokens"] = json!(max_tokens);
}
if let Some(temperature) = request.temperature {
generation_config["temperature"] = json!(temperature);
}
if let Some(top_p) = request.top_p {
generation_config["topP"] = json!(top_p);
}
if !generation_config.as_object().unwrap().is_empty() {
final_body["generationConfig"] = generation_config;
}
let url = format!(
"{}/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent",
self.base.base_url, self.project_id, self.region, request.model
);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json")
.json(&final_body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 | 403 => ProviderError::Authentication(error_text),
429 => ProviderError::RateLimit(error_text),
404 => ProviderError::ModelNotFound(error_text),
400 => ProviderError::InvalidRequest(error_text),
_ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
}
.into());
}
let vertex_response: serde_json::Value = self.base.parse_json_response(response).await?;
        self.convert_vertex_response_to_openai(vertex_response, &request.model)
}
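    /// Text completion via the `:predict` endpoint used by the PaLM
    /// `text-bison` family. Token usage is not extracted from the prediction
    /// metadata, so the response reports `Usage::default()`.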
async fn completion(
&self,
request: CompletionRequest,
_context: RequestContext,
) -> Result<CompletionResponse> {
debug!("Google Vertex AI completion for model: {}", request.model);
let mut provider = self.clone();
let access_token = provider.get_access_token().await?;
let body = json!({
"instances": [{
"prompt": request.prompt
}],
"parameters": {
"maxOutputTokens": request.max_tokens.unwrap_or(512),
"temperature": request.temperature.unwrap_or(0.7),
"topP": request.top_p.unwrap_or(1.0)
}
});
let url = format!(
"{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.base.base_url, self.project_id, self.region, request.model
);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
let vertex_response: serde_json::Value = self.base.parse_json_response(response).await?;
let text = vertex_response
.get("predictions")
.and_then(|p| p.as_array())
.and_then(|arr| arr.first())
.and_then(|pred| pred.get("content"))
.and_then(|content| content.as_str())
.unwrap_or("")
.to_string();
Ok(CompletionResponse {
id: format!("cmpl-vertex-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: request.model,
choices: vec![CompletionChoice {
text,
index: 0,
logprobs: None,
finish_reason: Some("stop".to_string()),
}],
usage: Some(Usage::default()),
})
}
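    /// Embeddings via `:predict` for the `textembedding-gecko` family. The
    /// input is sent as a single instance and only the first prediction's
    /// `embeddings.values` vector is read back.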
async fn embedding(
&self,
request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse> {
debug!("Google Vertex AI embedding for model: {}", request.model);
let mut provider = self.clone();
let access_token = provider.get_access_token().await?;
let body = json!({
"instances": [{
"content": request.input
}]
});
let url = format!(
"{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
self.base.base_url, self.project_id, self.region, request.model
);
let response = self
.base
.client
.post(&url)
.header("Authorization", format!("Bearer {}", access_token))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| ProviderError::Network(e.to_string()))?;
let vertex_response: serde_json::Value = self.base.parse_json_response(response).await?;
        let embedding_vec = vertex_response
            .get("predictions")
            .and_then(|p| p.as_array())
            .and_then(|arr| arr.first())
            .and_then(|pred| pred.get("embeddings"))
            .and_then(|emb| emb.get("values"))
            .and_then(|values| values.as_array())
            .map(|values| values.iter().filter_map(|v| v.as_f64()).collect())
            .unwrap_or_default();
let embeddings = vec![EmbeddingObject {
object: "embedding".to_string(),
embedding: embedding_vec,
index: 0,
}];
Ok(EmbeddingResponse {
object: "list".to_string(),
data: embeddings,
model: request.model,
usage: EmbeddingUsage {
prompt_tokens: 0,
total_tokens: 0,
},
})
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse> {
Err(ProviderError::InvalidRequest(
"Image generation not supported by Vertex AI text models".to_string(),
)
.into())
}
async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
if let Some(pricing) = self.pricing_cache.get(model) {
Ok(pricing.clone())
} else {
Ok(ModelPricing {
model: model.to_string(),
input_cost_per_1k: 0.0005,
output_cost_per_1k: 0.0015,
currency: "USD".to_string(),
updated_at: chrono::Utc::now(),
})
}
}
async fn calculate_cost(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64> {
let pricing = self.get_model_pricing(model).await?;
let input_cost = (input_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
let output_cost = (output_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
Ok(input_cost + output_cost)
}
}