use crate::llm::traits::AiProvider;
use crate::llm::types::{ChatCompletionParams, ProviderResponse};
use crate::llm::utils::normalize_model_name;
use anyhow::Result;
use std::env;
/// Provider implementation for Google Vertex AI / Gemini-family models.
///
/// Stateless unit struct: credentials are resolved from the environment on
/// demand (see `get_api_key`), so no per-instance configuration is stored.
#[derive(Debug, Clone)]
pub struct GoogleVertexProvider;
impl Default for GoogleVertexProvider {
fn default() -> Self {
Self::new()
}
}
impl GoogleVertexProvider {
pub fn new() -> Self {
Self
}
}
/// Environment variable pointing at a service-account credentials file.
const GOOGLE_APPLICATION_CREDENTIALS_ENV: &str = "GOOGLE_APPLICATION_CREDENTIALS";
/// Environment variable holding a plain Google API key.
const GOOGLE_API_KEY_ENV: &str = "GOOGLE_API_KEY";

/// Per-model pricing as `(model_id, input_rate, output_rate)`.
///
/// NOTE(review): the rates are presumably USD per 1M tokens — confirm with
/// whatever consumes this table. Currently unreferenced (`dead_code`).
#[allow(dead_code)]
const PRICING: &[(&str, f64, f64)] = &[
    ("gemini-3-pro", 2.00, 12.00),
    ("gemini-3-pro-preview", 2.00, 12.00),
    ("gemini-2.5-pro", 1.25, 5.00),
    ("gemini-2.5-flash", 0.075, 0.30),
    ("gemini-2.5-flash-lite", 0.10, 0.30),
    ("gemini-2.0-flash", 0.10, 0.40),
    ("gemini-2.0-flash-lite", 0.10, 0.30),
    ("gemini-2.0-flash-live", 0.35, 1.50),
    ("gemini-1.5-pro", 1.25, 5.00),
    ("gemini-1.5-flash", 0.075, 0.30),
    ("gemini-1.0-pro", 0.50, 1.50),
];
#[async_trait::async_trait]
impl AiProvider for GoogleVertexProvider {
    /// Canonical provider identifier used for registry/routing lookups.
    fn name(&self) -> &str {
        "google"
    }

    /// Accepts Google model families: Gemini, PaLM, and the legacy
    /// `text-`/`chat-` (bison) model names.
    fn supports_model(&self, model: &str) -> bool {
        let id = normalize_model_name(model);
        let prefix_match = ["text-", "chat-"].iter().any(|p| id.starts_with(p));
        let family_match = ["gemini", "palm", "text-bison", "chat-bison"]
            .iter()
            .any(|f| id.contains(f));
        prefix_match || family_match
    }

    /// Resolves authentication material from the environment.
    ///
    /// A plain API key takes precedence; otherwise the presence of
    /// service-account credentials yields the `"service_account_auth"`
    /// sentinel. Errors when neither variable is set.
    fn get_api_key(&self) -> Result<String> {
        match (
            env::var(GOOGLE_API_KEY_ENV),
            env::var(GOOGLE_APPLICATION_CREDENTIALS_ENV),
        ) {
            (Ok(key), _) => Ok(key),
            (Err(_), Ok(_credentials)) => Ok("service_account_auth".to_string()),
            (Err(_), Err(_)) => Err(anyhow::anyhow!(
                "Google authentication not found. Set either {} or {}",
                GOOGLE_API_KEY_ENV,
                GOOGLE_APPLICATION_CREDENTIALS_ENV
            )),
        }
    }

    /// Context caching is reported for Gemini 1.5 and later generations.
    fn supports_caching(&self, model: &str) -> bool {
        let id = normalize_model_name(model);
        ["gemini-1.5", "gemini-2", "gemini-3"]
            .iter()
            .any(|generation| id.contains(generation))
    }

    /// Every Gemini-family model is treated as vision-capable.
    fn supports_vision(&self, model: &str) -> bool {
        normalize_model_name(model).contains("gemini")
    }

    /// Maximum input-context size in tokens for the given model.
    ///
    /// Branch order matters: `bison-32k` must be checked before the generic
    /// `bison` fallback.
    /// NOTE(review): per-generation limits look inconsistent (gemini-2 at
    /// 2M vs gemini-3 at 1,048,576) — verify against current Vertex AI
    /// model documentation.
    fn get_max_input_tokens(&self, model: &str) -> usize {
        let id = normalize_model_name(model);
        if id.contains("gemini-3") {
            1_048_576
        } else if id.contains("gemini-2") {
            2_000_000
        } else if id.contains("gemini-1.5") {
            1_000_000
        } else if id.contains("gemini-1.0") || id.contains("bison-32k") {
            32_768
        } else if id.contains("bison") {
            8_192
        } else {
            // Conservative default for unrecognized model names.
            32_768
        }
    }

    /// Not implemented yet: always returns an error.
    async fn chat_completion(&self, _params: ChatCompletionParams) -> Result<ProviderResponse> {
        Err(anyhow::anyhow!(
            "Google Vertex AI provider not fully implemented in octolib"
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_supports_model() {
        let provider = GoogleVertexProvider::new();
        let accepted = ["gemini-1.5-pro", "gemini-2.0-flash", "gemini-1.0-pro", "text-bison"];
        for model in accepted {
            assert!(provider.supports_model(model), "should accept {}", model);
        }
        let rejected = ["gpt-4", "claude-3"];
        for model in rejected {
            assert!(!provider.supports_model(model), "should reject {}", model);
        }
    }

    #[test]
    fn test_supports_model_case_insensitive() {
        let provider = GoogleVertexProvider::new();
        let mixed_case = ["GEMINI-1.5-PRO", "GEMINI-2.0-FLASH", "Gemini-1.5-Pro", "GEMINI-1.0-pro"];
        for model in mixed_case {
            assert!(provider.supports_model(model), "should accept {}", model);
        }
    }

    #[test]
    fn test_supports_caching_case_insensitive() {
        let provider = GoogleVertexProvider::new();
        let cacheable = ["gemini-1.5-pro", "gemini-2.0-flash", "GEMINI-1.5-PRO", "Gemini-2.0-Flash"];
        for model in cacheable {
            assert!(provider.supports_caching(model), "should cache {}", model);
        }
    }
}