use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use crate::error::LlmError;
use crate::traits::ModelListingCapability;
use crate::types::ModelInfo;
use super::types::GeminiConfig;
/// A single model entry as returned by the Gemini model endpoints.
///
/// The REST API emits camelCase keys (`displayName`, `inputTokenLimit`, ...),
/// so every multi-word field accepts that spelling through a serde `alias`.
/// Without the aliases those fields silently deserialize to `None`.
/// Serialization is unchanged: fields are still written under the names this
/// type produced before.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiModel {
    /// Resource name of the model; may carry a `models/` prefix.
    pub name: String,
    /// Human-readable name, when the API provides one.
    #[serde(alias = "displayName", skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Free-form description of the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Version string reported for the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
    /// Input token limit reported for the model.
    #[serde(alias = "inputTokenLimit", skip_serializing_if = "Option::is_none")]
    pub input_token_limit: Option<i32>,
    /// Output token limit reported for the model.
    #[serde(alias = "outputTokenLimit", skip_serializing_if = "Option::is_none")]
    pub output_token_limit: Option<i32>,
    /// Generation methods the model supports (e.g. `generateContent`);
    /// defaults to an empty list when the key is absent.
    #[serde(default, rename = "supportedGenerationMethods")]
    pub supported_generation_methods: Vec<String>,
    /// Sampling temperature reported for the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling (`topP`) value reported for the model.
    #[serde(alias = "topP", skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k sampling (`topK`) value reported for the model.
    #[serde(alias = "topK", skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
}
/// One page of the model listing.
///
/// The API names the continuation field `nextPageToken`; it is accepted via a
/// serde `alias` so pagination can actually see the token. Serialization keeps
/// the existing `next_page_token` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListModelsResponse {
    /// Models contained in this page; empty when the key is absent.
    #[serde(default)]
    pub models: Vec<GeminiModel>,
    /// Token for requesting the following page; `None` on the last page.
    #[serde(alias = "nextPageToken", skip_serializing_if = "Option::is_none")]
    pub next_page_token: Option<String>,
}
/// Handle for querying the Gemini model endpoints (listing and single-model lookup).
#[derive(Debug, Clone)]
pub struct GeminiModels {
    /// Provider configuration; this file reads `base_url` and `api_key` from it.
    config: GeminiConfig,
    /// HTTP client used to issue the GET requests.
    http_client: HttpClient,
}
impl GeminiModels {
pub const fn new(config: GeminiConfig, http_client: HttpClient) -> Self {
Self {
config,
http_client,
}
}
fn convert_model(&self, model: GeminiModel) -> ModelInfo {
let id = model
.name
.strip_prefix("models/")
.unwrap_or(&model.name)
.to_string();
let mut capabilities = Vec::new();
if model
.supported_generation_methods
.contains(&"generateContent".to_string())
{
capabilities.push("chat".to_string());
}
if model
.supported_generation_methods
.contains(&"streamGenerateContent".to_string())
{
capabilities.push("streaming".to_string());
}
if id.contains("gemini") {
capabilities.extend_from_slice(&[
"vision".to_string(),
"function_calling".to_string(),
"code_execution".to_string(),
]);
}
let context_window = model.input_token_limit.unwrap_or_else(|| {
if id.contains("1.5-pro") {
2_000_000 } else if id.contains("1.5-flash") || id.contains("2.0") {
1_000_000 } else {
32_000 }
});
ModelInfo {
id,
name: Some(model.display_name.unwrap_or(model.name)),
description: model.description,
context_window: Some(context_window as u32),
max_output_tokens: model.output_token_limit.map(|t| t as u32),
capabilities,
input_cost_per_token: None,
output_cost_per_token: None,
created: None,
owned_by: "Google".to_string(),
}
}
async fn fetch_all_models(&self) -> Result<Vec<GeminiModel>, LlmError> {
let mut all_models = Vec::new();
let mut page_token: Option<String> = None;
loop {
let mut url = format!("{}/models", self.config.base_url);
let mut params = Vec::new();
if let Some(token) = &page_token {
params.push(format!("pageToken={token}"));
}
params.push("pageSize=50".to_string());
if !params.is_empty() {
url.push('?');
url.push_str(¶ms.join("&"));
}
let response = self
.http_client
.get(&url)
.header("x-goog-api-key", &self.config.api_key)
.send()
.await
.map_err(|e| LlmError::HttpError(e.to_string()))?;
if !response.status().is_success() {
let status_code = response.status().as_u16();
let error_text = response.text().await.unwrap_or_default();
return Err(LlmError::api_error(
status_code,
format!("Gemini API error: {status_code} - {error_text}"),
));
}
let list_response: ListModelsResponse = response.json().await.map_err(|e| {
LlmError::ParseError(format!("Failed to parse models response: {e}"))
})?;
all_models.extend(list_response.models);
if let Some(next_token) = list_response.next_page_token {
page_token = Some(next_token);
} else {
break;
}
}
Ok(all_models)
}
}
#[async_trait]
impl ModelListingCapability for GeminiModels {
    /// Returns every listed model that advertises the `generateContent`
    /// method, converted to [`ModelInfo`].
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        let raw_models = self.fetch_all_models().await?;

        let mut generative_models = Vec::new();
        for raw in raw_models {
            let can_generate = raw
                .supported_generation_methods
                .iter()
                .any(|method| method == "generateContent");
            if can_generate {
                generative_models.push(self.convert_model(raw));
            }
        }
        Ok(generative_models)
    }

    /// Looks up a single model by id; the id may be given with or without
    /// the `models/` prefix.
    async fn get_model(&self, model_id: String) -> Result<ModelInfo, LlmError> {
        // Qualify bare ids with `models/`; already-qualified ids pass through.
        let already_qualified = model_id.starts_with("models/");
        let resource_name = if already_qualified {
            model_id
        } else {
            format!("models/{model_id}")
        };
        let url = crate::utils::url::join_url(&self.config.base_url, &resource_name);

        let request = self
            .http_client
            .get(&url)
            .header("x-goog-api-key", &self.config.api_key);
        let response = request
            .send()
            .await
            .map_err(|e| LlmError::HttpError(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let status_code = status.as_u16();
            let error_text = response.text().await.unwrap_or_default();
            return Err(LlmError::api_error(
                status_code,
                format!("Gemini API error: {status_code} - {error_text}"),
            ));
        }

        match response.json::<GeminiModel>().await {
            Ok(model) => Ok(self.convert_model(model)),
            Err(e) => Err(LlmError::ParseError(format!(
                "Failed to parse model response: {e}"
            ))),
        }
    }
}
/// Returns the built-in list of Gemini model ids in a fixed order.
pub fn get_default_models() -> Vec<String> {
    const DEFAULT_MODEL_IDS: [&str; 8] = [
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.5-flash-lite",
        "gemini-2.0-flash",
        "gemini-2.0-flash-lite",
        "gemini-1.5-flash",
        "gemini-1.5-flash-8b",
        "gemini-1.5-pro",
    ];

    DEFAULT_MODEL_IDS
        .iter()
        .map(|id| (*id).to_string())
        .collect()
}
/// Reports whether `model_id` is assumed to support `capability`.
///
/// The decision is a pure substring heuristic on the model id; unknown
/// capability names always yield `false`.
pub fn model_supports_capability(model_id: &str, capability: &str) -> bool {
    // True when the id contains at least one of the given fragments.
    let mentions_any =
        |fragments: &[&str]| fragments.iter().any(|fragment| model_id.contains(*fragment));

    match capability {
        "chat" | "streaming" => true,
        "vision" | "function_calling" | "code_execution" => mentions_any(&["gemini"]),
        "thinking" => mentions_any(&["gemini-2.5", "gemini-2.0", "exp"]),
        "audio_generation" => mentions_any(&["tts", "live", "native-audio"]),
        "image_generation" => mentions_any(&["image-generation"]),
        "live_api" => mentions_any(&["live"]),
        _ => false,
    }
}
/// Returns the assumed context window (in tokens) for `model_id`.
///
/// The id is matched against known family fragments in order; the first hit
/// wins. Ids matching none of them get `128_000`.
pub fn get_model_context_window(model_id: &str) -> u32 {
    // Order matters: entries are checked top to bottom.
    const WINDOWS_BY_FRAGMENT: [(&str, u32); 5] = [
        ("2.5-pro", 1_048_576),
        ("2.5-flash", 1_048_576),
        ("2.0-flash", 1_048_576),
        ("1.5-pro", 2_097_152),
        ("1.5-flash", 1_048_576),
    ];
    const FALLBACK_WINDOW: u32 = 128_000;

    WINDOWS_BY_FRAGMENT
        .iter()
        .find(|(fragment, _)| model_id.contains(*fragment))
        .map_or(FALLBACK_WINDOW, |(_, window)| *window)
}
/// Returns the assumed maximum number of output tokens for `model_id`.
///
/// The decision is a substring heuristic on the id:
///
/// * ids containing `tts` get `16_000`;
/// * otherwise ids containing `2.5-pro` or `2.5-flash` get `65_536`;
/// * everything else (including the `2.0-flash`, `1.5-pro` and `1.5-flash`
///   families) gets `8192`.
pub fn get_model_max_output_tokens(model_id: &str) -> u32 {
    // The `tts` check must come first: an id such as
    // `gemini-2.5-flash-preview-tts` also contains a family fragment, and
    // checking the family first would make the TTS limit unreachable.
    if model_id.contains("tts") {
        16_000
    } else if model_id.contains("2.5-pro") || model_id.contains("2.5-flash") {
        65_536
    } else {
        8192
    }
}