use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use once_cell::sync::Lazy;
use regex::Regex;

#[cfg(feature = "gemini-llm")]
use crate::llm::TokenUsage;
use crate::llm::{
    BaseLlm, EnvTokenProvider, GcloudTokenProvider, LlmError, LlmRequest, LlmResponse,
    TokenProvider,
};
use crate::utils::variant::{get_google_llm_variant, GoogleLlmVariant};

/// Configuration for constructing a [`GeminiLlm`]. Every field is optional;
/// unset values fall back to environment variables or built-in defaults.
#[derive(Default)]
pub struct GeminiLlmParams {
    /// Model name; defaults to "gemini-2.5-flash".
    pub model: Option<String>,
    /// Gemini API key; falls back to `GOOGLE_GENAI_API_KEY`, then `GEMINI_API_KEY`.
    pub api_key: Option<String>,
    /// Force Vertex AI (`Some(true)`) or the Gemini API (`Some(false)`); auto-detected when unset.
    pub vertexai: Option<bool>,
    /// Google Cloud project for Vertex AI; falls back to `GOOGLE_CLOUD_PROJECT`.
    pub project: Option<String>,
    /// Google Cloud location for Vertex AI; falls back to `GOOGLE_CLOUD_LOCATION`, then "us-central1".
    pub location: Option<String>,
    /// Additional HTTP headers; currently stored but not applied to requests.
    pub headers: Option<HashMap<String, String>>,
    /// Custom token source; defaults to `GcloudTokenProvider` on Vertex AI, `EnvTokenProvider` otherwise.
    pub token_provider: Option<Arc<dyn TokenProvider>>,
}
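
/// A [`BaseLlm`] implementation backed by Google's Gemini models, targeting
/// either the public Gemini API or Vertex AI.
///
/// ```ignore
/// // Construct with defaults; credentials are resolved from the environment.
/// let llm = GeminiLlm::new(GeminiLlmParams::default());
/// assert_eq!(llm.model_id(), "gemini-2.5-flash");
/// ```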
pub struct GeminiLlm {
model: String,
variant: GoogleLlmVariant,
#[allow(dead_code)]
params: GeminiLlmParams,
#[allow(dead_code)]
token_provider: Arc<dyn TokenProvider>,
#[cfg(feature = "gemini-llm")]
client: rs_genai::prelude::Client,
}
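
/// Model identifiers this backend accepts: bare `gemini-*` names plus fully
/// qualified Vertex AI endpoint and model resource paths.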
static SUPPORTED_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
vec![
Regex::new(r"^gemini-.*$").unwrap(),
Regex::new(r"^projects/.*/endpoints/.*$").unwrap(),
Regex::new(r"^projects/.*/models/gemini.*$").unwrap(),
]
});
impl GeminiLlm {
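    /// Creates a new instance, resolving the model, backend variant, and
    /// credentials from `params` with environment-variable fallbacks.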
pub fn new(mut params: GeminiLlmParams) -> Self {
let model = params
.model
.clone()
.unwrap_or_else(|| "gemini-2.5-flash".to_string());
        let variant = match params.vertexai {
            Some(true) => GoogleLlmVariant::VertexAi,
            Some(false) => GoogleLlmVariant::GeminiApi,
            None => get_google_llm_variant(),
        };
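        // For the Gemini API, fall back to an API key from the environment.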
if params.api_key.is_none() && variant == GoogleLlmVariant::GeminiApi {
params.api_key = std::env::var("GOOGLE_GENAI_API_KEY")
.or_else(|_| std::env::var("GEMINI_API_KEY"))
.ok();
}
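        // Vertex AI needs a project and location; fall back to the standard
        // Google Cloud environment variables.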
if variant == GoogleLlmVariant::VertexAi {
if params.project.is_none() {
params.project = std::env::var("GOOGLE_CLOUD_PROJECT").ok();
}
if params.location.is_none() {
params.location = std::env::var("GOOGLE_CLOUD_LOCATION").ok();
}
}
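        // Default token provider: gcloud-backed (constructed with a 45-minute
        // duration) on Vertex AI, environment-backed otherwise.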
let token_provider: Arc<dyn TokenProvider> =
params.token_provider.take().unwrap_or_else(|| {
if variant == GoogleLlmVariant::VertexAi {
Arc::new(GcloudTokenProvider::new(std::time::Duration::from_secs(
45 * 60,
)))
} else {
Arc::new(EnvTokenProvider)
}
});
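        // Build the rs-genai client for the selected backend (feature-gated).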
#[cfg(feature = "gemini-llm")]
let client = {
use rs_genai::prelude::*;
match variant {
GoogleLlmVariant::GeminiApi => {
let api_key = params.api_key.as_deref().unwrap_or("");
Client::from_api_key(api_key).model(GeminiModel::Custom(model.clone()))
}
GoogleLlmVariant::VertexAi => {
let project = params.project.as_deref().unwrap_or("").to_string();
let location = params
.location
.as_deref()
.unwrap_or("us-central1")
.to_string();
let tp = token_provider.clone();
Client::from_vertex_refreshable(project, location, move || tp.token())
.model(GeminiModel::Custom(model.clone()))
}
}
};
Self {
model,
variant,
params,
token_provider,
#[cfg(feature = "gemini-llm")]
client,
}
}
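
    /// Returns `true` if `model` matches one of the supported name patterns.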
pub fn is_supported(model: &str) -> bool {
SUPPORTED_PATTERNS.iter().any(|re| re.is_match(model))
}
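
    /// The backend variant (Gemini API or Vertex AI) this instance targets.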
pub fn variant(&self) -> GoogleLlmVariant {
self.variant
}

    /// Hook for normalizing a request before dispatch; currently a no-op.
    fn preprocess_request(&self, _request: &mut LlmRequest) {}
}
#[async_trait]
impl BaseLlm for GeminiLlm {
fn model_id(&self) -> &str {
&self.model
}
async fn generate(&self, mut request: LlmRequest) -> Result<LlmResponse, LlmError> {
self.preprocess_request(&mut request);
#[cfg(feature = "gemini-llm")]
{
use rs_genai::generate::GenerateContentConfig;
use rs_genai::prelude::*;
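            // Translate the generic LlmRequest into an rs-genai request config.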
let mut config = if request.contents.is_empty() {
GenerateContentConfig::from_text("")
} else {
GenerateContentConfig::from_contents(std::mem::take(&mut request.contents))
};
if let Some(sys) = request.system_instruction.take() {
config = config.system_instruction(&sys);
}
if !request.tools.is_empty() {
config.tools = std::mem::take(&mut request.tools);
}
if let Some(temp) = request.temperature {
config = config.temperature(temp);
}
if let Some(max) = request.max_output_tokens {
config = config.max_output_tokens(max);
}
if request.response_mime_type.is_some() || request.response_json_schema.is_some() {
let gc = config
.generation_config
.get_or_insert_with(rs_genai::prelude::GenerationConfig::default);
if let Some(mime) = request.response_mime_type.take() {
gc.response_mime_type = Some(mime);
}
if let Some(schema) = request.response_json_schema.take() {
gc.response_json_schema = Some(schema);
}
}
let response = self
.client
.generate_content_with(config, None)
.await
.map_err(|e| LlmError::RequestFailed(e.to_string()))?;
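            // Take the first candidate's content, defaulting to an empty model turn.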
let content = response
.candidates
.first()
.and_then(|c| c.content.clone())
.unwrap_or_else(|| Content {
role: Some(Role::Model),
parts: vec![],
});
let finish_reason = response
.candidates
.first()
.and_then(|c| c.finish_reason)
.map(|r| format!("{:?}", r));
let usage = response.usage_metadata.map(|u| TokenUsage {
prompt_tokens: u.prompt_token_count.unwrap_or(0),
completion_tokens: u.response_token_count.unwrap_or(0),
total_tokens: u.total_token_count.unwrap_or(0),
});
Ok(LlmResponse {
content,
finish_reason,
usage,
})
}
#[cfg(not(feature = "gemini-llm"))]
{
let _ = request;
Err(LlmError::RequestFailed(
"GeminiLlm requires the 'gemini-llm' feature flag \
(depends on rs-genai HTTP client)"
.into(),
))
}
}
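
    /// Best-effort warm-up: sends a minimal single-token request when the
    /// `gemini-llm` feature is enabled, ignoring any error.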
async fn warm_up(&self) -> Result<(), LlmError> {
#[cfg(feature = "gemini-llm")]
{
use rs_genai::generate::GenerateContentConfig;
let config = GenerateContentConfig::from_text(".").max_output_tokens(1);
let _ = self.client.generate_content_with(config, None).await;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_model_is_gemini_2_5_flash() {
let llm = GeminiLlm::new(GeminiLlmParams::default());
assert_eq!(llm.model_id(), "gemini-2.5-flash");
}
#[test]
fn explicit_model() {
let llm = GeminiLlm::new(GeminiLlmParams {
model: Some("gemini-2.0-pro".into()),
..Default::default()
});
assert_eq!(llm.model_id(), "gemini-2.0-pro");
}
#[test]
fn variant_from_params_vertex() {
let llm = GeminiLlm::new(GeminiLlmParams {
vertexai: Some(true),
..Default::default()
});
assert_eq!(llm.variant(), GoogleLlmVariant::VertexAi);
}
#[test]
fn variant_from_params_gemini_api() {
let llm = GeminiLlm::new(GeminiLlmParams {
vertexai: Some(false),
..Default::default()
});
assert_eq!(llm.variant(), GoogleLlmVariant::GeminiApi);
}
#[test]
fn is_supported_gemini_models() {
assert!(GeminiLlm::is_supported("gemini-2.5-flash"));
assert!(GeminiLlm::is_supported("gemini-2.0-pro"));
assert!(GeminiLlm::is_supported("gemini-1.5-pro-001"));
}
#[test]
fn is_supported_non_gemini_models() {
assert!(!GeminiLlm::is_supported("gpt-4"));
assert!(!GeminiLlm::is_supported("claude-3-opus"));
assert!(!GeminiLlm::is_supported("llama-3"));
}
#[test]
fn is_supported_vertex_ai_resource_paths() {
assert!(GeminiLlm::is_supported(
"projects/my-project/endpoints/12345"
));
assert!(GeminiLlm::is_supported(
"projects/my-project/models/gemini-2.5-flash"
));
}
#[test]
fn model_id_returns_correct_string() {
let llm = GeminiLlm::new(GeminiLlmParams {
model: Some("gemini-2.5-flash-preview-04-17".into()),
..Default::default()
});
assert_eq!(llm.model_id(), "gemini-2.5-flash-preview-04-17");
}
#[test]
fn base_llm_is_object_safe() {
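        // Compiles only if `dyn BaseLlm` is a valid type, i.e. the trait is object-safe.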
fn _assert_object_safe(_: &dyn BaseLlm) {}
}
#[test]
fn gemini_llm_is_send_sync() {
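        // Compile-time check that GeminiLlm can be moved and shared across threads.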
fn _assert_send_sync<T: Send + Sync>() {}
_assert_send_sync::<GeminiLlm>();
}
}