use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, instrument};
use crate::error::{LlmError, Result};
use crate::traits::{
ChatMessage,
ChatRole,
CompletionOptions,
EmbeddingProvider,
LLMProvider,
LLMResponse,
StreamChunk,
StreamUsage,
ToolCall,
ToolChoice,
ToolDefinition,
};
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
const DEFAULT_GEMINI_MODEL: &str = "gemini-2.5-flash";
const DEFAULT_EMBEDDING_MODEL: &str = "gemini-embedding-001";
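/// Backend selection for the provider: the public Google AI (Generative Language)
/// API authenticated with an API key, or Vertex AI authenticated with an OAuth
/// access token scoped to a project and region.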
#[derive(Debug, Clone)]
pub enum GeminiEndpoint {
GoogleAI { api_key: String },
VertexAI {
project_id: String,
region: String,
access_token: String,
},
}
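/// Tracks the most recently created cached-content resource and a hash of the
/// system instruction it was built from, so identical system prompts can reuse
/// the same cache entry.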
#[derive(Debug, Default)]
struct CacheState {
content_id: Option<String>,
system_hash: Option<u64>,
}
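/// Gemini chat/completion and embedding provider supporting both the Google AI
/// and Vertex AI endpoints.
///
/// A minimal construction sketch ("api-key" is a placeholder; the example is
/// marked `ignore` because the import path depends on the consuming crate):
///
/// ```ignore
/// let provider = GeminiProvider::new("api-key")
///     .with_model("gemini-2.5-pro")
///     .with_embedding_model("text-embedding-004");
/// ```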
#[derive(Debug)]
pub struct GeminiProvider {
client: Client,
endpoint: GeminiEndpoint,
model: String,
embedding_model: String,
max_context_length: usize,
embedding_dimension: usize,
cache_ttl: String,
cache_state: tokio::sync::RwLock<CacheState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Blob {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct Part {
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub inline_data: Option<Blob>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_response: Option<FunctionResponse>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub thought: Option<bool>,
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiTool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
pub parts: Vec<Part>,
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
}
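/// Generation parameters mapped onto the request's `generationConfig` object;
/// only fields that are explicitly set are serialized.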
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_mime_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
}
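/// Thinking/reasoning controls: Gemini 2.5 models take a numeric `thinking_budget`,
/// while Gemini 3 models take a `thinking_level` string instead.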
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct ThinkingConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub include_thoughts: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_level: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_budget: Option<i32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: String,
pub threshold: String,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct CreateCachedContentRequest {
model: String,
#[serde(skip_serializing_if = "Option::is_none")]
contents: Option<Vec<Content>>,
#[serde(skip_serializing_if = "Option::is_none")]
system_instruction: Option<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<serde_json::Value>>,
ttl: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CachedContentResponse {
name: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[allow(dead_code)]
usage_metadata: Option<UsageMetadata>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentRequest {
contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
generation_config: Option<GenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
system_instruction: Option<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
safety_settings: Option<Vec<SafetySetting>>,
#[serde(skip_serializing_if = "Option::is_none")]
cached_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<GeminiTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_config: Option<ToolConfig>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Candidate {
content: Content,
finish_reason: Option<String>,
#[serde(default)]
safety_ratings: Vec<serde_json::Value>,
}
#[derive(Debug, Clone, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct UsageMetadata {
#[serde(default)]
prompt_token_count: usize,
#[serde(default)]
candidates_token_count: usize,
#[serde(default)]
total_token_count: usize,
#[serde(default)]
cached_content_token_count: usize,
#[serde(default)]
thoughts_token_count: usize,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
candidates: Option<Vec<Candidate>>,
usage_metadata: Option<UsageMetadata>,
prompt_feedback: Option<PromptFeedback>,
#[allow(dead_code)]
model_version: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PromptFeedback {
block_reason: Option<String>,
#[allow(dead_code)]
safety_ratings: Option<Vec<serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct EmbedContentRequest {
content: Content,
#[serde(skip_serializing_if = "Option::is_none")]
model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
task_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
output_dimensionality: Option<usize>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct BatchEmbedContentsRequest {
requests: Vec<EmbedContentRequest>,
}
#[derive(Debug, Clone, Deserialize)]
struct EmbeddingValues {
values: Vec<f32>,
}
#[derive(Debug, Clone, Deserialize)]
struct EmbedContentResponse {
embedding: EmbeddingValues,
}
#[derive(Debug, Clone, Deserialize)]
struct BatchEmbedContentsResponse {
embeddings: Vec<EmbeddingValues>,
}
#[derive(Debug, Clone, Serialize)]
struct VertexAIEmbedInstance {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
task_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIEmbedParameters {
#[serde(skip_serializing_if = "Option::is_none")]
output_dimensionality: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
auto_truncate: Option<bool>,
}
#[derive(Debug, Clone, Serialize)]
struct VertexAIEmbedRequest {
instances: Vec<VertexAIEmbedInstance>,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<VertexAIEmbedParameters>,
}
#[derive(Debug, Clone, Deserialize)]
struct VertexAIEmbedPrediction {
embeddings: VertexAIEmbeddingResult,
}
#[derive(Debug, Clone, Deserialize)]
struct VertexAIEmbeddingResult {
values: Vec<f32>,
}
#[derive(Debug, Clone, Deserialize)]
struct VertexAIEmbedResponse {
predictions: Vec<VertexAIEmbedPrediction>,
}
#[derive(Debug, Clone, Deserialize)]
struct GeminiErrorResponse {
error: GeminiError,
}
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct GeminiError {
code: i32,
message: String,
status: String,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiModelsResponse {
#[serde(default)]
pub models: Vec<GeminiModelInfo>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiModelInfo {
pub name: String,
#[serde(default)]
pub display_name: String,
#[serde(default)]
pub description: String,
#[serde(default)]
pub input_token_limit: Option<u32>,
#[serde(default)]
pub output_token_limit: Option<u32>,
#[serde(default)]
pub supported_generation_methods: Vec<String>,
}
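/// How a model family expresses reasoning controls: no thinking support, a
/// numeric token budget with model-specific bounds, or a named thinking level.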
#[derive(Debug, Clone, Copy, PartialEq)]
pub(super) enum ThinkingStyle {
None,
Budget { min: i32, max: i32 },
Level,
}
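/// Static capabilities keyed by model-name prefix: context window, thinking
/// style, whether thinking is on by default, whether the API name needs a
/// `-preview` suffix, and whether Vertex AI requires the `global` location.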
pub(super) struct ModelProfile {
pub prefix: &'static str,
pub context_length: usize,
pub thinking: ThinkingStyle,
pub thinks_by_default: bool,
pub auto_preview_suffix: bool,
pub requires_global_vertex: bool,
}
static MODEL_PROFILES: &[ModelProfile] = &[
ModelProfile {
prefix: "gemini-3.1-",
context_length: 2_000_000,
thinking: ThinkingStyle::Level,
thinks_by_default: true,
auto_preview_suffix: true,
requires_global_vertex: true,
},
ModelProfile {
prefix: "gemini-3-flash",
context_length: 2_000_000,
thinking: ThinkingStyle::Level,
thinks_by_default: true,
auto_preview_suffix: true,
requires_global_vertex: true,
},
ModelProfile {
prefix: "gemini-3-",
context_length: 2_000_000,
thinking: ThinkingStyle::Level,
thinks_by_default: true,
auto_preview_suffix: false,
requires_global_vertex: true,
},
ModelProfile {
prefix: "gemini-2.5-flash-lite",
context_length: 1_048_576,
thinking: ThinkingStyle::Budget {
min: 512,
max: 24_576,
},
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-2.5-flash",
context_length: 1_048_576,
thinking: ThinkingStyle::Budget {
min: 0,
max: 24_576,
},
thinks_by_default: true,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-2.5-pro",
context_length: 1_048_576,
thinking: ThinkingStyle::Budget {
min: 128,
max: 32_768,
},
thinks_by_default: true,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-2.5-",
context_length: 1_048_576,
thinking: ThinkingStyle::Budget {
min: 0,
max: 24_576,
},
thinks_by_default: true,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-2.0-",
context_length: 1_048_576,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-1.5-pro",
context_length: 2_000_000,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-1.5-flash",
context_length: 1_000_000,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-1.0-",
context_length: 32_000,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
ModelProfile {
prefix: "gemini-pro",
context_length: 32_000,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
},
];
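/// Fallback profile for model names that match no known prefix.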
static DEFAULT_PROFILE: ModelProfile = ModelProfile {
prefix: "",
context_length: 1_048_576,
thinking: ThinkingStyle::None,
thinks_by_default: false,
auto_preview_suffix: false,
requires_global_vertex: false,
};
impl GeminiProvider {
fn stream_usage_from_metadata(usage: Option<&UsageMetadata>) -> Option<StreamUsage> {
let usage = usage?;
let mut stream_usage =
StreamUsage::new(usage.prompt_token_count, usage.candidates_token_count);
if usage.cached_content_token_count > 0 {
stream_usage = stream_usage.with_cache_hit_tokens(usage.cached_content_token_count);
}
if usage.thoughts_token_count > 0 {
stream_usage = stream_usage.with_thinking_tokens(usage.thoughts_token_count);
}
Some(stream_usage)
}
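/// Creates a provider against the public Google AI endpoint using an API key,
/// with the default chat and embedding models.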
pub fn new(api_key: impl Into<String>) -> Self {
Self {
client: Client::new(),
endpoint: GeminiEndpoint::GoogleAI {
api_key: api_key.into(),
},
model: DEFAULT_GEMINI_MODEL.to_string(),
embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(),
max_context_length: Self::context_length_for_model(DEFAULT_GEMINI_MODEL),
embedding_dimension: 3072,
cache_ttl: "3600s".to_string(),
cache_state: tokio::sync::RwLock::new(CacheState::default()),
}
}
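/// Builds a provider from the environment: `GEMINI_API_KEY` selects the Google AI
/// endpoint; otherwise Vertex AI configuration is read via `from_env_vertex_ai`.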
pub fn from_env() -> Result<Self> {
if let Ok(api_key) = std::env::var("GEMINI_API_KEY") {
return Ok(Self::new(api_key));
}
Self::from_env_vertex_ai()
}
pub fn from_env_vertex_ai() -> Result<Self> {
let project_id = std::env::var("GOOGLE_CLOUD_PROJECT").map_err(|_| {
LlmError::ConfigError(
"VertexAI requires GOOGLE_CLOUD_PROJECT environment variable. \
Run: export GOOGLE_CLOUD_PROJECT=your-project-id"
.to_string(),
)
})?;
let region =
std::env::var("GOOGLE_CLOUD_REGION").unwrap_or_else(|_| "us-central1".to_string());
let access_token = match std::env::var("GOOGLE_ACCESS_TOKEN") {
Ok(token) if !token.is_empty() => token,
_ => Self::get_access_token_from_gcloud()?,
};
Ok(Self::vertex_ai(project_id, region, access_token))
}
fn get_access_token_from_gcloud() -> Result<String> {
debug!("Obtaining access token via gcloud auth print-access-token");
if let Ok(token) = Self::run_gcloud_token_cmd(&["auth", "print-access-token"]) {
return Ok(token);
}
debug!("Falling back to gcloud auth application-default print-access-token");
if let Ok(token) =
Self::run_gcloud_token_cmd(&["auth", "application-default", "print-access-token"])
{
return Ok(token);
}
Err(LlmError::ConfigError(
"Could not obtain a Google Cloud access token. \
Run one of the following and try again:\n \
gcloud auth login\n \
gcloud auth application-default login"
.to_string(),
))
}
fn run_gcloud_token_cmd(args: &[&str]) -> Result<String> {
use std::process::Command;
let output = Command::new("gcloud")
.args(args)
.output()
.map_err(|e| LlmError::ConfigError(format!("Failed to run gcloud: {}", e)))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(LlmError::ConfigError(stderr.trim().to_string()));
}
let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
if token.is_empty() {
return Err(LlmError::ConfigError("empty token".to_string()));
}
Ok(token)
}
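/// Creates a provider against Vertex AI for the given project, region, and OAuth
/// access token.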
pub fn vertex_ai(
project_id: impl Into<String>,
region: impl Into<String>,
access_token: impl Into<String>,
) -> Self {
Self {
client: Client::new(),
endpoint: GeminiEndpoint::VertexAI {
project_id: project_id.into(),
region: region.into(),
access_token: access_token.into(),
},
model: DEFAULT_GEMINI_MODEL.to_string(),
embedding_model: DEFAULT_EMBEDDING_MODEL.to_string(),
max_context_length: Self::context_length_for_model(DEFAULT_GEMINI_MODEL),
embedding_dimension: 3072,
cache_ttl: "3600s".to_string(),
cache_state: tokio::sync::RwLock::new(CacheState::default()),
}
}
pub fn with_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
self.cache_ttl = ttl.into();
self
}
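/// Sets the chat model and updates the context-length limit from the matching
/// model profile.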
pub fn with_model(mut self, model: impl Into<String>) -> Self {
let model_name = model.into();
self.max_context_length = Self::context_length_for_model(&model_name);
self.model = model_name;
self
}
pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
let model_name = model.into();
self.embedding_dimension = Self::dimension_for_model(&model_name);
self.embedding_model = model_name;
self
}
pub fn with_embedding_dimension(mut self, dimension: usize) -> Self {
self.embedding_dimension = dimension;
self
}
fn strip_provider_prefix(model: &str) -> &str {
model
.strip_prefix("vertexai:")
.or_else(|| model.strip_prefix("gemini:"))
.or_else(|| model.strip_prefix("google:"))
.unwrap_or(model)
}
fn lookup_profile(model: &str) -> &'static ModelProfile {
let bare = Self::strip_provider_prefix(model);
MODEL_PROFILES
.iter()
.find(|p| bare.starts_with(p.prefix))
.unwrap_or(&DEFAULT_PROFILE)
}
fn profile(&self) -> &'static ModelProfile {
Self::lookup_profile(&self.model)
}
pub fn context_length_for_model(model: &str) -> usize {
Self::lookup_profile(model).context_length
}
pub fn dimension_for_model(model: &str) -> usize {
match model {
m if m.contains("gemini-embedding-2") => 3072,
m if m.contains("gemini-embedding-001") => 3072,
m if m.contains("text-embedding-004") => 768,
m if m.contains("text-embedding-005") => 768,
m if m.contains("text-multilingual-embedding-002") => 768,
_ => 3072,
}
}
fn apply_generation_options(config: &mut GenerationConfig, options: &CompletionOptions) {
if let Some(max_tokens) = options.max_tokens {
config.max_output_tokens = Some(max_tokens);
}
if let Some(temp) = options.temperature {
config.temperature = Some(temp);
}
if let Some(top_p) = options.top_p {
config.top_p = Some(top_p);
}
if let Some(ref stop) = options.stop {
config.stop_sequences = Some(stop.clone());
}
if options.response_format.as_deref() == Some("json_object") {
config.response_mime_type = Some("application/json".to_string());
}
}
pub async fn list_models(&self) -> Result<GeminiModelsResponse> {
let url = match &self.endpoint {
GeminiEndpoint::GoogleAI { api_key } => {
format!("{}/models?key={}", GEMINI_API_BASE, api_key)
}
GeminiEndpoint::VertexAI {
project_id,
region,
access_token: _,
} => {
let host = Self::vertex_host(region);
format!(
"https://{}/v1/projects/{}/locations/{}/publishers/google/models",
host, project_id, region
)
}
};
debug!(url = %url, "Fetching Gemini models list");
let mut req = self.client.get(&url);
if let GeminiEndpoint::VertexAI { access_token, .. } = &self.endpoint {
req = req.bearer_auth(access_token);
}
let response = req
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Failed to fetch Gemini models: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(LlmError::ApiError(format!(
"Gemini /models returned {}: {}",
status, body
)));
}
response
.json::<GeminiModelsResponse>()
.await
.map_err(|e| LlmError::ProviderError(format!("Failed to parse models response: {}", e)))
}
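/// Creates (or reuses) a `cachedContents` resource for the given system
/// instruction, keyed by a hash of its JSON form, and returns its resource name.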
#[instrument(skip(self, system_instruction))]
async fn ensure_cache(&self, system_instruction: &Content) -> Result<String> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
serde_json::to_string(system_instruction)
.map_err(LlmError::SerializationError)?
.hash(&mut hasher);
let current_hash = hasher.finish();
{
let cache = self.cache_state.read().await;
if let (Some(cache_id), Some(cached_hash)) = (&cache.content_id, cache.system_hash) {
if cached_hash == current_hash {
debug!("Reusing cached content: {}", cache_id);
return Ok(cache_id.clone());
} else {
debug!("System instruction changed, creating new cache");
}
}
}
debug!("Creating cached content (ttl: {})", self.cache_ttl);
let request = CreateCachedContentRequest {
model: format!("models/{}", self.model),
contents: None,
system_instruction: Some(system_instruction.clone()),
tools: None,
ttl: self.cache_ttl.clone(),
};
let url = match &self.endpoint {
GeminiEndpoint::GoogleAI { api_key } => {
format!("{}/cachedContents?key={}", GEMINI_API_BASE, api_key)
}
GeminiEndpoint::VertexAI {
project_id, region, ..
} => {
let effective_region: &str = if self.model.contains("gemini-3") {
"global"
} else {
region.as_str()
};
let host = Self::vertex_host(effective_region);
format!(
"https://{}/v1beta/projects/{}/locations/{}/cachedContents",
host, project_id, effective_region
)
}
};
let mut req = self.client.post(&url).json(&request);
if let GeminiEndpoint::VertexAI { access_token, .. } = &self.endpoint {
req = req.bearer_auth(access_token);
}
let response = req
.send()
.await
.map_err(|e| LlmError::NetworkError(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::ApiError(format!(
"Failed to create cached content (status {}): {}",
status, error_text
)));
}
let cache_response: CachedContentResponse = response.json().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to parse cache response: {}", e))
})?;
{
let mut cache = self.cache_state.write().await;
cache.content_id = Some(cache_response.name.clone());
cache.system_hash = Some(current_hash);
}
debug!("Created cached content: {}", cache_response.name);
Ok(cache_response.name)
}
fn vertex_host(region: &str) -> String {
if region == "global" {
"aiplatform.googleapis.com".to_string()
} else {
format!("{}-aiplatform.googleapis.com", region)
}
}
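/// Builds the request URL for a model/action pair, adding a `-preview` suffix
/// when the profile expects one and the name lacks it, and routing models that
/// require it to the Vertex AI `global` location.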
fn build_url(&self, model: &str, action: &str) -> String {
let bare = Self::strip_provider_prefix(model);
let profile = Self::lookup_profile(bare);
let model_name = if profile.auto_preview_suffix && !bare.ends_with("-preview") {
format!("{}-preview", bare)
} else {
bare.to_string()
};
match &self.endpoint {
GeminiEndpoint::GoogleAI { api_key } => {
format!(
"{}/models/{}:{}?key={}",
GEMINI_API_BASE, model_name, action, api_key
)
}
GeminiEndpoint::VertexAI {
project_id, region, ..
} => {
let effective_region: &str = if profile.requires_global_vertex {
"global"
} else {
region.as_str()
};
let host = Self::vertex_host(effective_region);
format!(
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:{}",
host, project_id, effective_region, model_name, action
)
}
}
}
fn auth_headers(&self) -> Vec<(&'static str, String)> {
match &self.endpoint {
GeminiEndpoint::GoogleAI { .. } => {
vec![]
}
GeminiEndpoint::VertexAI { access_token, .. } => {
vec![("Authorization", format!("Bearer {}", access_token))]
}
}
}
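/// Converts provider-agnostic chat messages into a Gemini system instruction plus
/// `contents`, mapping assistant tool calls to `functionCall` parts and folding
/// consecutive tool results into a single user turn of `functionResponse` parts.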
fn convert_messages(messages: &[ChatMessage]) -> (Option<Content>, Vec<Content>) {
let mut system_instruction = None;
let mut contents: Vec<Content> = Vec::new();
let mut i = 0;
while i < messages.len() {
let msg = &messages[i];
match msg.role {
ChatRole::System => {
system_instruction = Some(Content {
parts: vec![Part {
text: Some(msg.content.clone()),
..Default::default()
}],
role: None,
});
i += 1;
}
ChatRole::User => {
if msg.has_images() {
let mut parts = Vec::new();
if !msg.content.is_empty() {
parts.push(Part {
text: Some(msg.content.clone()),
..Default::default()
});
}
if let Some(ref images) = msg.images {
for img in images {
parts.push(Part {
inline_data: Some(Blob {
mime_type: img.mime_type.clone(),
data: img.data.clone(),
}),
..Default::default()
});
}
}
contents.push(Content {
parts,
role: Some("user".to_string()),
});
} else {
contents.push(Content {
parts: vec![Part {
text: Some(msg.content.clone()),
..Default::default()
}],
role: Some("user".to_string()),
});
}
i += 1;
}
ChatRole::Assistant => {
let mut parts = Vec::new();
if !msg.content.is_empty() {
parts.push(Part {
text: Some(msg.content.clone()),
..Default::default()
});
}
if let Some(ref tool_calls) = msg.tool_calls {
for tc in tool_calls {
let args: serde_json::Value =
serde_json::from_str(&tc.function.arguments).unwrap_or_else(|_| {
serde_json::Value::Object(serde_json::Map::new())
});
parts.push(Part {
function_call: Some(FunctionCall {
name: tc.function.name.clone(),
args,
}),
thought_signature: tc.thought_signature.clone(),
..Default::default()
});
}
}
if parts.is_empty() {
parts.push(Part {
text: Some(String::new()),
..Default::default()
});
}
contents.push(Content {
parts,
role: Some("model".to_string()),
});
i += 1;
}
ChatRole::Tool | ChatRole::Function => {
let mut parts = Vec::new();
while i < messages.len()
&& matches!(messages[i].role, ChatRole::Tool | ChatRole::Function)
{
let tool_msg = &messages[i];
let fn_name = tool_msg.name.as_deref().unwrap_or("unknown_function");
let response_value: serde_json::Value =
serde_json::from_str(&tool_msg.content).unwrap_or_else(
|_| serde_json::json!({ "content": tool_msg.content.clone() }),
);
parts.push(Part {
function_response: Some(FunctionResponse {
name: fn_name.to_string(),
response: response_value,
}),
..Default::default()
});
i += 1;
}
if !parts.is_empty() {
contents.push(Content {
parts,
role: Some("user".to_string()),
});
}
}
}
}
(system_instruction, contents)
}
async fn send_request<T: for<'de> Deserialize<'de>>(
&self,
url: &str,
body: &impl Serialize,
) -> Result<T> {
let mut request = self.client.post(url).json(body);
for (key, value) in self.auth_headers() {
request = request.header(key, value);
}
let response = request
.send()
.await
.map_err(|e| LlmError::ApiError(format!("Request failed: {}", e)))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|e| LlmError::ApiError(format!("Failed to read response: {}", e)))?;
if !status.is_success() {
if let Ok(error_response) = serde_json::from_str::<GeminiErrorResponse>(&text) {
return Err(LlmError::ApiError(format!(
"Gemini API error ({}): {}",
error_response.error.code, error_response.error.message
)));
}
return Err(LlmError::ApiError(format!(
"Gemini API error ({}): {}",
status, text
)));
}
serde_json::from_str(&text).map_err(|e| {
LlmError::ApiError(format!("Failed to parse response: {}. Body: {}", e, text))
})
}
fn sanitize_parameters(mut params: serde_json::Value) -> serde_json::Value {
if let Some(obj) = params.as_object_mut() {
obj.remove("$schema");
for (_key, value) in obj.iter_mut() {
if value.is_object() || value.is_array() {
*value = Self::sanitize_parameters(value.clone());
}
}
} else if let Some(arr) = params.as_array_mut() {
for item in arr.iter_mut() {
if item.is_object() || item.is_array() {
*item = Self::sanitize_parameters(item.clone());
}
}
}
params
}
fn convert_tools(tools: &[ToolDefinition]) -> Vec<GeminiTool> {
let declarations: Vec<FunctionDeclaration> = tools
.iter()
.map(|tool| {
let sanitized_params = Self::sanitize_parameters(tool.function.parameters.clone());
FunctionDeclaration {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
parameters: Some(sanitized_params),
}
})
.collect();
vec![GeminiTool {
function_declarations: declarations,
}]
}
fn convert_tool_choice(tool_choice: Option<ToolChoice>) -> Option<ToolConfig> {
let mode = match &tool_choice {
None => "AUTO",
Some(ToolChoice::Auto(s)) if s == "auto" => "AUTO",
Some(ToolChoice::Auto(s)) if s == "none" => "NONE",
Some(ToolChoice::Auto(_)) => "AUTO",
Some(ToolChoice::Required(_)) => "ANY",
Some(ToolChoice::Function { function, .. }) => {
return Some(ToolConfig {
function_calling_config: FunctionCallingConfig {
mode: "ANY".to_string(),
allowed_function_names: Some(vec![function.name.clone()]),
},
});
}
};
Some(ToolConfig {
function_calling_config: FunctionCallingConfig {
mode: mode.to_string(),
allowed_function_names: None,
},
})
}
pub fn supports_thinking(&self) -> bool {
!matches!(self.profile().thinking, ThinkingStyle::None)
}
pub fn model_thinks_by_default(&self) -> bool {
self.profile().thinks_by_default
}
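/// Translates completion options into a `ThinkingConfig` appropriate for the
/// current model: a token budget for budget-style models (Gemini 2.5), a thinking
/// level for level-style models (Gemini 3), and `None` when the model has no
/// thinking support or no thinking option was set.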
pub fn build_thinking_config(
&self,
options: &crate::traits::CompletionOptions,
) -> Option<ThinkingConfig> {
let include_thoughts = options.gemini_include_thoughts;
let budget = options.gemini_thinking_budget;
let level = options.gemini_thinking_level.clone();
if include_thoughts.is_none() && budget.is_none() && level.is_none() {
return None;
}
match self.profile().thinking {
ThinkingStyle::None => None,
ThinkingStyle::Level => Some(ThinkingConfig {
include_thoughts,
thinking_level: level
.or_else(|| include_thoughts.filter(|&v| v).map(|_| "high".to_string())),
thinking_budget: None,
}),
ThinkingStyle::Budget { .. } => Some(ThinkingConfig {
include_thoughts,
thinking_level: None,
thinking_budget: budget.or_else(|| include_thoughts.filter(|&v| v).map(|_| -1i32)),
}),
}
}
#[cfg(test)]
pub fn default_thinking_config(model: &str) -> ThinkingConfig {
let profile = Self::lookup_profile(model);
match profile.thinking {
ThinkingStyle::Level => ThinkingConfig {
include_thoughts: Some(true),
thinking_level: Some("high".to_string()),
thinking_budget: None,
},
_ => ThinkingConfig {
include_thoughts: Some(true),
thinking_level: None,
thinking_budget: Some(-1),
},
}
}
}
#[async_trait]
impl LLMProvider for GeminiProvider {
fn name(&self) -> &str {
match &self.endpoint {
GeminiEndpoint::GoogleAI { .. } => "gemini",
GeminiEndpoint::VertexAI { .. } => "vertex-ai",
}
}
fn model(&self) -> &str {
&self.model
}
fn max_context_length(&self) -> usize {
self.max_context_length
}
#[instrument(skip(self, prompt), fields(model = %self.model))]
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.complete_with_options(prompt, &CompletionOptions::default())
.await
}
#[instrument(skip(self, prompt, options), fields(model = %self.model))]
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
let mut messages = Vec::new();
if let Some(system) = &options.system_prompt {
messages.push(ChatMessage::system(system));
}
messages.push(ChatMessage::user(prompt));
self.chat(&messages, Some(options)).await
}
#[instrument(skip(self, messages, options), fields(model = %self.model))]
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
let (system_instruction, contents) = Self::convert_messages(messages);
if contents.is_empty() {
return Err(LlmError::InvalidRequest(
"No user messages provided".to_string(),
));
}
let options = options.cloned().unwrap_or_default();
let mut generation_config = GenerationConfig::default();
Self::apply_generation_options(&mut generation_config, &options);
if let Some(thinking_cfg) = self.build_thinking_config(&options) {
generation_config.thinking_config = Some(thinking_cfg);
}
let cached_content = if let Some(system_inst) = system_instruction.as_ref() {
match self.ensure_cache(system_inst).await {
Ok(cache_id) => Some(cache_id),
Err(e) => {
debug!(
"Failed to create/reuse cache: {}, continuing without cache",
e
);
None
}
}
} else {
None
};
let request = GenerateContentRequest {
contents,
generation_config: Some(generation_config),
system_instruction: if cached_content.is_none() {
system_instruction
} else {
None
},
safety_settings: None,
cached_content,
tools: None,
tool_config: None,
};
let url = self.build_url(&self.model, "generateContent");
debug!("Sending request to Gemini: {}", url);
let response: GenerateContentResponse = self.send_request(&url, &request).await?;
let candidates = match response.candidates {
Some(c) if !c.is_empty() => c,
_ => {
let block_reason = response
.prompt_feedback
.as_ref()
.and_then(|pf| pf.block_reason.as_deref())
.unwrap_or("unknown");
return Err(LlmError::ApiError(format!(
"Gemini blocked the request (promptFeedback.blockReason: {}). \
Adjust your prompt or safety settings.",
block_reason
)));
}
};
let candidate = candidates
.first()
.ok_or_else(|| LlmError::ApiError("Empty candidates array".to_string()))?;
let mut content = String::new();
let mut thinking_content_parts: Vec<String> = Vec::new();
for part in &candidate.content.parts {
if let Some(text) = &part.text {
if part.thought == Some(true) {
thinking_content_parts.push(text.clone());
} else {
content.push_str(text);
}
}
}
let thinking_content = if thinking_content_parts.is_empty() {
None
} else {
Some(thinking_content_parts.join(""))
};
let usage = response.usage_metadata.unwrap_or_default();
let mut metadata = HashMap::new();
if !candidate.safety_ratings.is_empty() {
metadata.insert(
"safety_ratings".to_string(),
serde_json::json!(candidate.safety_ratings),
);
}
Ok(LLMResponse {
content,
prompt_tokens: usage.prompt_token_count,
completion_tokens: usage.candidates_token_count,
total_tokens: usage.total_token_count,
model: self.model.clone(),
finish_reason: candidate.finish_reason.clone(),
tool_calls: Vec::new(),
metadata,
cache_hit_tokens: if usage.cached_content_token_count > 0 {
Some(usage.cached_content_token_count)
} else {
None
},
thinking_tokens: if usage.thoughts_token_count > 0 {
Some(usage.thoughts_token_count)
} else {
None
},
thinking_content,
})
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
let (system_instruction, contents) = Self::convert_messages(messages);
if contents.is_empty() {
return Err(LlmError::InvalidRequest(
"No user messages provided".to_string(),
));
}
let options = options.cloned().unwrap_or_default();
let mut generation_config = GenerationConfig::default();
Self::apply_generation_options(&mut generation_config, &options);
if let Some(thinking_cfg) = self.build_thinking_config(&options) {
generation_config.thinking_config = Some(thinking_cfg);
}
let gemini_tools = if tools.is_empty() {
None
} else {
Some(Self::convert_tools(tools))
};
let gemini_tool_config = Self::convert_tool_choice(tool_choice);
let request = GenerateContentRequest {
contents,
generation_config: Some(generation_config),
system_instruction,
safety_settings: None,
cached_content: None,
tools: gemini_tools,
tool_config: gemini_tool_config,
};
let url = self.build_url(&self.model, "generateContent");
debug!("Sending chat_with_tools request to Gemini: {}", url);
let response: GenerateContentResponse = self.send_request(&url, &request).await?;
let candidates = match response.candidates {
Some(c) if !c.is_empty() => c,
_ => {
let block_reason = response
.prompt_feedback
.as_ref()
.and_then(|pf| pf.block_reason.as_deref())
.unwrap_or("unknown");
return Err(LlmError::ApiError(format!(
"Gemini blocked the request (promptFeedback.blockReason: {}). \
Adjust your prompt or safety settings.",
block_reason
)));
}
};
let candidate = candidates
.first()
.ok_or_else(|| LlmError::ApiError("Empty candidates array".to_string()))?;
let mut content = String::new();
let mut tool_calls = Vec::new();
let mut thinking_content_parts: Vec<String> = Vec::new();
let mut pending_sig: Option<String> = None;
for part in &candidate.content.parts {
if let Some(text) = &part.text {
if part.thought == Some(true) {
thinking_content_parts.push(text.clone());
} else {
content.push_str(text);
}
}
if let Some(fc) = &part.function_call {
let sig = part
.thought_signature
.clone()
.or_else(|| pending_sig.take());
tool_calls.push(ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4().to_string().replace('-', "")),
call_type: "function".to_string(),
function: crate::traits::FunctionCall {
name: fc.name.clone(),
arguments: fc.args.to_string(),
},
thought_signature: sig,
});
} else if part.thought_signature.is_some() {
pending_sig = part.thought_signature.clone();
}
}
let thinking_content = if thinking_content_parts.is_empty() {
None
} else {
Some(thinking_content_parts.join(""))
};
let usage = response.usage_metadata.unwrap_or_default();
let mut metadata = HashMap::new();
if !candidate.safety_ratings.is_empty() {
metadata.insert(
"safety_ratings".to_string(),
serde_json::json!(candidate.safety_ratings),
);
}
Ok(LLMResponse {
content,
prompt_tokens: usage.prompt_token_count,
completion_tokens: usage.candidates_token_count,
total_tokens: usage.total_token_count,
model: self.model.clone(),
finish_reason: candidate.finish_reason.clone(),
tool_calls,
metadata,
cache_hit_tokens: if usage.cached_content_token_count > 0 {
Some(usage.cached_content_token_count)
} else {
None
},
thinking_tokens: if usage.thoughts_token_count > 0 {
Some(usage.thoughts_token_count)
} else {
None
},
thinking_content,
})
}
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
use futures::StreamExt;
let messages = vec![ChatMessage::user(prompt)];
let (system_instruction, contents) = Self::convert_messages(&messages);
let generation_config = None::<GenerationConfig>;
let request = GenerateContentRequest {
contents,
generation_config,
system_instruction,
safety_settings: None,
cached_content: None,
tools: None,
tool_config: None,
};
let base_url = self.build_url(&self.model, "streamGenerateContent");
let url = if base_url.contains('?') {
format!("{}&alt=sse", base_url)
} else {
format!("{}?alt=sse", base_url)
};
let mut req = self.client.post(&url).json(&request);
for (key, value) in self.auth_headers() {
req = req.header(key, value);
}
let response = req
.send()
.await
.map_err(|e| LlmError::ApiError(format!("Stream request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
if status.as_u16() == 429 || text.contains("RESOURCE_EXHAUSTED") {
return Err(LlmError::RateLimited(format!("Stream error: {}", text)));
}
return Err(LlmError::ApiError(format!("Stream error: {}", text)));
}
let stream = response.bytes_stream();
let mapped_stream = stream.map(|result| {
match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let mut content_parts = Vec::new();
for line in text.lines() {
if let Some(json_str) = line.strip_prefix("data: ") {
if let Ok(chunk) =
serde_json::from_str::<GenerateContentResponse>(json_str)
{
if let Some(candidates) = chunk.candidates {
if let Some(candidate) = candidates.first() {
let content: String = candidate
.content
.parts
.iter()
.filter_map(|p| p.text.clone())
.collect();
if !content.is_empty() {
content_parts.push(content);
}
}
}
}
}
}
Ok(content_parts.join(""))
}
Err(e) => Err(LlmError::ApiError(format!("Stream error: {}", e))),
}
});
Ok(mapped_stream.boxed())
}
fn supports_streaming(&self) -> bool {
true
}
fn supports_json_mode(&self) -> bool {
self.model.contains("gemini-1.5")
|| self.model.contains("gemini-2")
|| self.model.contains("gemini-3")
}
fn supports_function_calling(&self) -> bool {
self.model.contains("gemini-1.5")
|| self.model.contains("gemini-2")
|| self.model.contains("gemini-3")
}
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
use futures::StreamExt;
let (system_instruction, contents) = Self::convert_messages(messages);
if contents.is_empty() {
return Err(LlmError::InvalidRequest(
"No user messages provided".to_string(),
));
}
let gemini_tools = if !tools.is_empty() {
Some(Self::convert_tools(tools))
} else {
None
};
let tool_config = Self::convert_tool_choice(tool_choice);
let options = options.cloned().unwrap_or_default();
let mut generation_config = GenerationConfig::default();
Self::apply_generation_options(&mut generation_config, &options);
if let Some(thinking_cfg) = self.build_thinking_config(&options) {
generation_config.thinking_config = Some(thinking_cfg);
}
let request = GenerateContentRequest {
contents,
generation_config: Some(generation_config),
system_instruction,
safety_settings: None,
cached_content: None,
tools: gemini_tools,
tool_config,
};
let base_url = self.build_url(&self.model, "streamGenerateContent");
let url = if base_url.contains('?') {
format!("{}&alt=sse", base_url)
} else {
format!("{}?alt=sse", base_url)
};
let mut req = self.client.post(&url).json(&request);
for (key, value) in self.auth_headers() {
req = req.header(key, value);
}
let response = req
.send()
.await
.map_err(|e| LlmError::ApiError(format!("Stream request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
if status.as_u16() == 429 || text.contains("RESOURCE_EXHAUSTED") {
return Err(LlmError::RateLimited(format!("Stream error: {}", text)));
}
return Err(LlmError::ApiError(format!("Stream error: {}", text)));
}
let stream = response.bytes_stream();
let fn_call_index = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let pending_sig: std::sync::Arc<std::sync::Mutex<Option<String>>> =
std::sync::Arc::new(std::sync::Mutex::new(None));
let latest_usage: std::sync::Arc<std::sync::Mutex<Option<StreamUsage>>> =
std::sync::Arc::new(std::sync::Mutex::new(None));
let mapped_stream = stream.flat_map(move |result| {
let fn_call_index = fn_call_index.clone();
let pending_sig = pending_sig.clone();
let latest_usage = latest_usage.clone();
let items: Vec<crate::error::Result<StreamChunk>> = match result {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
let mut chunks: Vec<crate::error::Result<StreamChunk>> = Vec::new();
for line in text.lines() {
if let Some(json_str) = line.strip_prefix("data: ") {
if let Ok(sse_response) =
serde_json::from_str::<GenerateContentResponse>(json_str)
{
let stream_usage = Self::stream_usage_from_metadata(
sse_response.usage_metadata.as_ref(),
);
if let Some(ref usage) = stream_usage {
if let Ok(mut latest) = latest_usage.lock() {
*latest = Some(usage.clone());
}
}
if let Some(candidates) = sse_response.candidates {
if let Some(candidate) = candidates.first() {
for part in &candidate.content.parts {
if let Some(text_content) = &part.text {
if !text_content.is_empty() {
if part.thought == Some(true) {
chunks.push(Ok(
StreamChunk::ThinkingContent {
text: text_content.clone(),
tokens_used: None,
budget_total: None,
},
));
} else {
chunks.push(Ok(StreamChunk::Content(
text_content.clone(),
)));
}
}
}
if let Some(func_call) = &part.function_call {
let args_json =
serde_json::to_string(&func_call.args).ok();
let idx = fn_call_index.fetch_add(
1,
std::sync::atomic::Ordering::Relaxed,
);
let sig =
part.thought_signature.clone().or_else(|| {
pending_sig
.lock()
.ok()
.and_then(|mut s| s.take())
});
chunks.push(Ok(StreamChunk::ToolCallDelta {
index: idx,
id: Some(uuid::Uuid::new_v4().to_string()),
function_name: Some(func_call.name.clone()),
function_arguments: args_json,
thought_signature: sig,
}));
} else if part.thought_signature.is_some() {
if let Ok(mut ps) = pending_sig.lock() {
*ps = part.thought_signature.clone();
}
}
}
if let Some(ref reason) = candidate.finish_reason {
let mapped_reason = match reason.as_str() {
"STOP" => "stop",
"MAX_TOKENS" => "length",
"SAFETY" => "content_filter",
_ => reason.as_str(),
};
let usage = stream_usage.or_else(|| {
latest_usage
.lock()
.ok()
.and_then(|latest| latest.clone())
});
chunks.push(Ok(StreamChunk::Finished {
reason: mapped_reason.to_string(),
ttft_ms: None,
usage,
}));
}
}
}
}
}
}
chunks
}
Err(e) => vec![Err(LlmError::ApiError(format!("Stream error: {}", e)))],
};
futures::stream::iter(items)
});
Ok(mapped_stream.boxed())
}
fn supports_tool_streaming(&self) -> bool {
self.model.contains("gemini-1.5")
|| self.model.contains("gemini-2")
|| self.model.contains("gemini-3")
}
}
#[async_trait]
impl EmbeddingProvider for GeminiProvider {
fn name(&self) -> &str {
"gemini"
}
#[allow(clippy::misnamed_getters)]
fn model(&self) -> &str {
&self.embedding_model
}
fn dimension(&self) -> usize {
self.embedding_dimension
}
fn max_tokens(&self) -> usize {
2048
}
#[instrument(skip(self, texts), fields(model = %self.embedding_model, count = texts.len()))]
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
if matches!(&self.endpoint, GeminiEndpoint::VertexAI { .. }) {
return self.embed_vertex_ai(texts).await;
}
if texts.len() > 1 {
return self.embed_batch(texts).await;
}
let output_dim =
if self.embedding_dimension != Self::dimension_for_model(&self.embedding_model) {
Some(self.embedding_dimension)
} else {
None
};
let request = EmbedContentRequest {
content: Content {
parts: vec![Part {
text: Some(texts[0].clone()),
..Default::default()
}],
role: None,
},
model: None,
task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
title: None,
output_dimensionality: output_dim,
};
let url = self.build_url(&self.embedding_model, "embedContent");
debug!("Sending embedding request to Gemini: {}", url);
let response: EmbedContentResponse = self.send_request(&url, &request).await?;
Ok(vec![response.embedding.values])
}
async fn embed_one(&self, text: &str) -> Result<Vec<f32>> {
let results = self.embed(&[text.to_string()]).await?;
results
.into_iter()
.next()
.ok_or_else(|| LlmError::Unknown("Empty embedding result".to_string()))
}
}
impl GeminiProvider {
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let model_path = format!("models/{}", self.embedding_model);
let output_dim =
if self.embedding_dimension != Self::dimension_for_model(&self.embedding_model) {
Some(self.embedding_dimension)
} else {
None
};
let requests: Vec<EmbedContentRequest> = texts
.iter()
.map(|text| EmbedContentRequest {
content: Content {
parts: vec![Part {
text: Some(text.clone()),
..Default::default()
}],
role: None,
},
model: Some(model_path.clone()),
task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
title: None,
output_dimensionality: output_dim,
})
.collect();
let batch_request = BatchEmbedContentsRequest { requests };
let url = self.build_url(&self.embedding_model, "batchEmbedContents");
debug!("Sending batch embedding request to Gemini: {}", url);
let response: BatchEmbedContentsResponse = self.send_request(&url, &batch_request).await?;
Ok(response.embeddings.into_iter().map(|e| e.values).collect())
}
async fn embed_vertex_ai(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let instances: Vec<VertexAIEmbedInstance> = texts
.iter()
.map(|text| VertexAIEmbedInstance {
content: text.clone(),
task_type: Some("RETRIEVAL_DOCUMENT".to_string()),
title: None,
})
.collect();
let request = VertexAIEmbedRequest {
instances,
parameters: Some(VertexAIEmbedParameters {
output_dimensionality: Some(self.embedding_dimension),
auto_truncate: Some(true),
}),
};
let url = self.build_url(&self.embedding_model, "predict");
debug!("Sending VertexAI embedding request: {}", url);
let response: VertexAIEmbedResponse = self.send_request(&url, &request).await?;
Ok(response
.predictions
.into_iter()
.map(|p| p.embeddings.values)
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_context_length_detection() {
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.0-flash"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.5-pro"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.0-pro"),
32_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.5-flash-lite"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.5-flash"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-3.1-pro-preview"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-3-flash"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.5-flash"),
1_000_000
);
}
#[test]
fn test_embedding_dimension_detection() {
assert_eq!(
GeminiProvider::dimension_for_model("gemini-embedding-001"),
3072
);
assert_eq!(
GeminiProvider::dimension_for_model("text-embedding-004"),
768
);
assert_eq!(
GeminiProvider::dimension_for_model("text-embedding-005"),
768
);
}
#[test]
fn test_provider_builder() {
let provider = GeminiProvider::new("test-key")
.with_model("gemini-1.5-pro")
.with_embedding_model("text-embedding-004");
assert_eq!(LLMProvider::model(&provider), "gemini-1.5-pro");
assert_eq!(provider.dimension(), 768);
assert_eq!(provider.max_context_length(), 2_000_000);
}
#[test]
fn test_provider_builder_default_embedding() {
let provider = GeminiProvider::new("test-key");
assert_eq!(EmbeddingProvider::model(&provider), "gemini-embedding-001");
assert_eq!(provider.dimension(), 3072);
}
#[test]
fn test_vertex_ai_provider() {
let provider = GeminiProvider::vertex_ai("my-project", "us-central1", "test-token");
assert_eq!(LLMProvider::name(&provider), "vertex-ai");
}
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let (system, contents) = GeminiProvider::convert_messages(&messages);
assert!(system.is_some());
assert_eq!(
system.unwrap().parts[0].text,
Some("You are helpful".to_string())
);
assert_eq!(contents.len(), 2);
assert_eq!(contents[0].role.as_deref(), Some("user"));
assert_eq!(contents[1].role.as_deref(), Some("model"));
}
#[test]
fn test_convert_messages_text_only() {
let messages = vec![ChatMessage::user("Hello, world!")];
let (_, contents) = GeminiProvider::convert_messages(&messages);
assert_eq!(contents.len(), 1);
let json = serde_json::to_value(&contents[0]).unwrap();
let parts = &json["parts"];
assert!(parts.is_array());
assert_eq!(parts.as_array().unwrap().len(), 1);
assert_eq!(parts[0]["text"], "Hello, world!");
assert!(parts[0].get("inlineData").is_none());
}
#[test]
fn test_convert_messages_with_images() {
use crate::traits::ImageData;
let images = vec![ImageData::new("base64data", "image/png")];
let messages = vec![ChatMessage::user_with_images("What's this?", images)];
let (_, contents) = GeminiProvider::convert_messages(&messages);
assert_eq!(contents.len(), 1);
let json = serde_json::to_value(&contents[0]).unwrap();
let parts = &json["parts"];
assert!(parts.is_array());
assert_eq!(parts.as_array().unwrap().len(), 2);
assert_eq!(parts[0]["text"], "What's this?");
assert!(
parts[1].get("inlineData").is_some(),
"Should have inlineData for image"
);
assert_eq!(parts[1]["inlineData"]["mimeType"], "image/png");
assert_eq!(parts[1]["inlineData"]["data"], "base64data");
}
#[test]
fn test_convert_messages_multiple_images() {
use crate::traits::ImageData;
let images = vec![
ImageData::new("img1data", "image/png"),
ImageData::new("img2data", "image/jpeg"),
];
let messages = vec![ChatMessage::user_with_images("Compare these", images)];
let (_, contents) = GeminiProvider::convert_messages(&messages);
let json = serde_json::to_value(&contents[0]).unwrap();
let parts = &json["parts"];
assert_eq!(parts.as_array().unwrap().len(), 3);
assert_eq!(parts[1]["inlineData"]["mimeType"], "image/png");
assert_eq!(parts[2]["inlineData"]["mimeType"], "image/jpeg");
}
#[test]
fn test_build_url_google_ai() {
let provider = GeminiProvider::new("test-api-key");
let url = provider.build_url("gemini-2.0-flash", "generateContent");
assert!(url.contains("generativelanguage.googleapis.com"));
assert!(url.contains("gemini-2.0-flash"));
assert!(url.contains("key=test-api-key"));
}
#[test]
fn test_build_url_vertex_ai() {
let provider = GeminiProvider::vertex_ai("my-project", "us-central1", "token");
let url = provider.build_url("gemini-2.0-flash", "generateContent");
assert!(url.contains("aiplatform.googleapis.com"));
assert!(url.contains("my-project"));
assert!(url.contains("us-central1"));
}
#[test]
fn test_build_url_vertex_ai_predict() {
let provider = GeminiProvider::vertex_ai("my-project", "us-central1", "token");
let url = provider.build_url("gemini-embedding-001", "predict");
assert!(url.contains("aiplatform.googleapis.com"));
assert!(url.contains("gemini-embedding-001"));
assert!(url.contains(":predict"));
assert!(url.contains("my-project"));
assert!(url.contains("us-central1"));
}
#[test]
fn test_build_url_vertex_ai_global_region() {
let provider = GeminiProvider::vertex_ai("my-project", "global", "token");
let url = provider.build_url("gemini-3-flash-preview", "generateContent");
assert!(
url.contains("https://aiplatform.googleapis.com"),
"global region must use aiplatform.googleapis.com (no prefix), got: {}",
url
);
assert!(
!url.contains("global-aiplatform"),
"must NOT contain 'global-aiplatform', got: {}",
url
);
assert!(
url.contains("/locations/global/"),
"must contain /locations/global/, got: {}",
url
);
assert!(url.contains("gemini-3-flash-preview"), "got: {}", url);
}
#[test]
fn test_vertex_host_regional() {
assert_eq!(
GeminiProvider::vertex_host("us-central1"),
"us-central1-aiplatform.googleapis.com"
);
assert_eq!(
GeminiProvider::vertex_host("europe-west4"),
"europe-west4-aiplatform.googleapis.com"
);
}
#[test]
fn test_vertex_host_global() {
assert_eq!(
GeminiProvider::vertex_host("global"),
"aiplatform.googleapis.com"
);
}
#[test]
fn test_supports_thinking_gemini_25() {
let provider = GeminiProvider::new("key").with_model("gemini-2.5-flash");
assert!(provider.supports_thinking());
let provider = GeminiProvider::new("key").with_model("gemini-2.5-pro");
assert!(provider.supports_thinking());
let provider =
GeminiProvider::vertex_ai("proj", "us-central1", "tok").with_model("gemini-2.5-flash");
assert!(provider.supports_thinking());
let provider =
GeminiProvider::vertex_ai("proj", "us-central1", "tok").with_model("gemini-2.5-pro");
assert!(provider.supports_thinking());
}
#[test]
fn test_supports_thinking_gemini_3() {
let provider = GeminiProvider::new("key").with_model("gemini-3-flash");
assert!(provider.supports_thinking());
let provider = GeminiProvider::new("key").with_model("gemini-3-pro");
assert!(provider.supports_thinking());
let provider =
GeminiProvider::vertex_ai("proj", "us-central1", "tok").with_model("gemini-3-flash");
assert!(provider.supports_thinking());
let provider = GeminiProvider::new("key").with_model("gemini-3.1-pro-preview");
assert!(provider.supports_thinking());
}
#[test]
fn test_supports_thinking_gemini_1x() {
let provider = GeminiProvider::new("key").with_model("gemini-1.5-flash");
assert!(!provider.supports_thinking());
let provider =
GeminiProvider::vertex_ai("proj", "us-central1", "tok").with_model("gemini-1.5-flash");
assert!(!provider.supports_thinking());
let provider = GeminiProvider::new("key").with_model("gemini-1.0-pro");
assert!(!provider.supports_thinking());
let provider = GeminiProvider::new("key").with_model("gemini-2.0-flash");
assert!(!provider.supports_thinking());
}
#[test]
fn test_default_thinking_config_gemini3() {
let cfg = GeminiProvider::default_thinking_config("gemini-3-flash-preview");
assert!(
cfg.thinking_level.is_some(),
"Gemini 3 requires thinking_level to be set"
);
assert_eq!(cfg.thinking_level.as_deref(), Some("high"));
assert!(cfg.thinking_budget.is_none());
let cfg = GeminiProvider::default_thinking_config("gemini-3.1-pro-preview");
assert_eq!(cfg.thinking_level.as_deref(), Some("high"));
}
#[test]
fn test_default_thinking_config_gemini25() {
let cfg = GeminiProvider::default_thinking_config("gemini-2.5-flash");
assert!(
cfg.thinking_budget.is_some(),
"Gemini 2.5 requires thinking_budget to be set"
);
assert_eq!(cfg.thinking_budget, Some(-1));
assert!(cfg.thinking_level.is_none());
let cfg = GeminiProvider::default_thinking_config("gemini-2.5-pro");
assert_eq!(cfg.thinking_budget, Some(-1));
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
include_thoughts: Some(true),
thinking_level: Some("high".to_string()),
thinking_budget: Some(1024),
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["includeThoughts"], true);
assert_eq!(json["thinkingLevel"], "high");
assert_eq!(json["thinkingBudget"], 1024);
}
#[test]
fn test_part_thought_deserialization() {
let json = r#"{"text": "thinking...", "thought": true}"#;
let part: Part = serde_json::from_str(json).unwrap();
assert_eq!(part.text, Some("thinking...".to_string()));
assert_eq!(part.thought, Some(true));
}
#[test]
fn test_part_thought_defaults_to_none() {
let json = r#"{"text": "response"}"#;
let part: Part = serde_json::from_str(json).unwrap();
assert_eq!(part.text, Some("response".to_string()));
assert_eq!(part.thought, None);
}
#[test]
fn test_usage_metadata_thoughts_token_count() {
let json = r#"{"promptTokenCount": 100, "candidatesTokenCount": 50, "totalTokenCount": 150, "thoughtsTokenCount": 25}"#;
let usage: UsageMetadata = serde_json::from_str(json).unwrap();
assert_eq!(usage.prompt_token_count, 100);
assert_eq!(usage.candidates_token_count, 50);
assert_eq!(usage.thoughts_token_count, 25);
}
#[test]
fn test_constants() {
assert_eq!(
GEMINI_API_BASE,
"https://generativelanguage.googleapis.com/v1beta"
);
assert_eq!(DEFAULT_GEMINI_MODEL, "gemini-2.5-flash");
assert_eq!(DEFAULT_EMBEDDING_MODEL, "gemini-embedding-001");
}
#[test]
fn test_google_ai_provider_name() {
let provider = GeminiProvider::new("test-key");
assert_eq!(LLMProvider::name(&provider), "gemini");
}
#[test]
fn test_supports_streaming() {
let provider = GeminiProvider::new("test-key");
assert!(provider.supports_streaming());
}
#[test]
fn test_supports_json_mode_gemini_25() {
let provider = GeminiProvider::new("key").with_model("gemini-2.5-flash");
assert!(provider.supports_json_mode());
}
#[test]
fn test_supports_json_mode_gemini_15() {
let provider = GeminiProvider::new("key").with_model("gemini-1.5-pro");
assert!(provider.supports_json_mode());
}
#[test]
fn test_supports_json_mode_gemini_3() {
let provider = GeminiProvider::new("key").with_model("gemini-3-flash");
assert!(provider.supports_json_mode());
let provider = GeminiProvider::new("key").with_model("gemini-3-pro");
assert!(provider.supports_json_mode());
}
#[test]
fn test_supports_json_mode_gemini_10() {
let provider = GeminiProvider::new("key").with_model("gemini-1.0-pro");
assert!(!provider.supports_json_mode());
}
#[test]
fn test_with_cache_ttl() {
let provider = GeminiProvider::new("key").with_cache_ttl("7200s");
assert_eq!(provider.cache_ttl, "7200s");
}
#[test]
fn test_embedding_provider_name() {
let provider = GeminiProvider::new("key");
assert_eq!(EmbeddingProvider::name(&provider), "gemini");
}
#[test]
fn test_embedding_provider_model() {
let provider = GeminiProvider::new("key").with_embedding_model("text-embedding-005");
assert_eq!(EmbeddingProvider::model(&provider), "text-embedding-005");
}
#[test]
fn test_embedding_provider_max_tokens() {
let provider = GeminiProvider::new("key");
assert!(EmbeddingProvider::max_tokens(&provider) > 0);
}
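// An empty batch has nothing to embed, so embed_batch must reject it with an
// error up front.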
#[tokio::test]
async fn test_embed_empty_input() {
let provider = GeminiProvider::new("key");
let texts: Vec<String> = vec![];
let result = provider.embed_batch(&texts).await;
assert!(result.is_err());
}
#[test]
fn test_generation_config_serialization() {
let config = GenerationConfig {
max_output_tokens: Some(1000),
temperature: Some(0.7),
top_p: Some(0.9),
top_k: Some(40),
stop_sequences: Some(vec!["END".to_string()]),
response_mime_type: Some("application/json".to_string()),
thinking_config: None,
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["maxOutputTokens"], 1000);
let temp = json["temperature"].as_f64().unwrap();
assert!((temp - 0.7).abs() < 0.001);
let top_p = json["topP"].as_f64().unwrap();
assert!((top_p - 0.9).abs() < 0.001);
assert_eq!(json["topK"], 40);
assert_eq!(json["stopSequences"], serde_json::json!(["END"]));
assert_eq!(json["responseMimeType"], "application/json");
}
#[test]
fn test_gemini_models_response_deserialization() {
let json = r#"{
"models": [
{
"name": "models/gemini-2.5-flash",
"displayName": "Gemini 2.5 Flash",
"description": "Fast model",
"inputTokenLimit": 1000000,
"outputTokenLimit": 8192,
"supportedGenerationMethods": ["generateContent"]
}
]
}"#;
let response: GeminiModelsResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.models.len(), 1);
assert_eq!(response.models[0].name, "models/gemini-2.5-flash");
assert_eq!(response.models[0].display_name, "Gemini 2.5 Flash");
assert_eq!(response.models[0].input_token_limit, Some(1000000));
}
#[test]
fn test_function_call_deserialization() {
let json = r#"{"name": "get_weather", "args": {"location": "London"}}"#;
let fc: FunctionCall = serde_json::from_str(json).unwrap();
assert_eq!(fc.name, "get_weather");
assert_eq!(fc.args["location"], "London");
}
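// Message conversion: assistant tool calls become functionCall parts on a
// single "model" Content, tool results become functionResponse parts under
// the "user" role, consecutive tool results are grouped into one Content,
// and a complete tool-calling turn round-trips as four Content blocks.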
#[test]
fn test_convert_messages_assistant_with_tool_calls() {
use crate::traits::{FunctionCall as LlmFunctionCall, ToolCall};
let tc1 = ToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: LlmFunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location":"London"}"#.to_string(),
},
thought_signature: None,
};
let tc2 = ToolCall {
id: "call_2".to_string(),
call_type: "function".to_string(),
function: LlmFunctionCall {
name: "get_time".to_string(),
arguments: r#"{"timezone":"UTC"}"#.to_string(),
},
thought_signature: None,
};
let msg = ChatMessage::assistant_with_tools("", vec![tc1, tc2]);
let (_, contents) = GeminiProvider::convert_messages(&[msg]);
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role.as_deref(), Some("model"));
let parts = &contents[0].parts;
assert_eq!(
parts.len(),
2,
"expected 2 functionCall parts, got {}",
parts.len()
);
let fc1 = parts[0]
.function_call
.as_ref()
.expect("part 0 should be functionCall");
assert_eq!(fc1.name, "get_weather");
assert_eq!(fc1.args["location"], "London");
let fc2 = parts[1]
.function_call
.as_ref()
.expect("part 1 should be functionCall");
assert_eq!(fc2.name, "get_time");
assert_eq!(fc2.args["timezone"], "UTC");
}
#[test]
fn test_convert_messages_tool_result() {
let mut tool_msg = ChatMessage::tool_result("call_1", r#"{"temp":20}"#);
tool_msg.name = Some("get_weather".to_string());
let (_, contents) = GeminiProvider::convert_messages(&[tool_msg]);
assert_eq!(contents.len(), 1);
assert_eq!(contents[0].role.as_deref(), Some("user"));
let part = &contents[0].parts[0];
let fr = part
.function_response
.as_ref()
.expect("part should be functionResponse");
assert_eq!(fr.name, "get_weather");
assert_eq!(fr.response["temp"], 20);
}
#[test]
fn test_convert_messages_tool_result_plain_text() {
let mut tool_msg = ChatMessage::tool_result("call_1", "done");
tool_msg.name = Some("run_command".to_string());
let (_, contents) = GeminiProvider::convert_messages(&[tool_msg]);
let fr = contents[0].parts[0]
.function_response
.as_ref()
.expect("should be functionResponse");
assert_eq!(fr.name, "run_command");
assert_eq!(fr.response["content"], "done");
}
#[test]
fn test_convert_messages_parallel_tool_results_grouped() {
let mut tr1 = ChatMessage::tool_result("call_1", r#"{"temp":20}"#);
tr1.name = Some("get_weather".to_string());
let mut tr2 = ChatMessage::tool_result("call_2", r#"{"time":"12:00"}"#);
tr2.name = Some("get_time".to_string());
let (_, contents) = GeminiProvider::convert_messages(&[tr1, tr2]);
assert_eq!(
contents.len(),
1,
"parallel tool results must be grouped into a single user Content"
);
assert_eq!(contents[0].role.as_deref(), Some("user"));
assert_eq!(
contents[0].parts.len(),
2,
"expected one Part per tool result"
);
let fr1 = contents[0].parts[0].function_response.as_ref().unwrap();
let fr2 = contents[0].parts[1].function_response.as_ref().unwrap();
assert_eq!(fr1.name, "get_weather");
assert_eq!(fr2.name, "get_time");
}
#[test]
fn test_convert_messages_full_tool_calling_turn() {
use crate::traits::{FunctionCall as LlmFunctionCall, ToolCall};
let user_msg = ChatMessage::user("What's the weather in London?");
let tc = ToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: LlmFunctionCall {
name: "get_weather".to_string(),
arguments: r#"{"location":"London"}"#.to_string(),
},
thought_signature: None,
};
let assistant_msg = ChatMessage::assistant_with_tools("", vec![tc]);
let mut tool_result =
ChatMessage::tool_result("call_1", r#"{"temp":18,"condition":"cloudy"}"#);
tool_result.name = Some("get_weather".to_string());
let final_answer = ChatMessage::assistant("It is 18°C and cloudy in London.");
let messages = vec![user_msg, assistant_msg, tool_result, final_answer];
let (_, contents) = GeminiProvider::convert_messages(&messages);
assert_eq!(contents.len(), 4, "expected 4 Content blocks");
assert_eq!(contents[0].role.as_deref(), Some("user"));
assert!(contents[0].parts[0].text.is_some());
assert_eq!(contents[1].role.as_deref(), Some("model"));
assert!(contents[1].parts[0].function_call.is_some());
assert_eq!(contents[2].role.as_deref(), Some("user"));
assert!(contents[2].parts[0].function_response.is_some());
let fr = contents[2].parts[0].function_response.as_ref().unwrap();
assert_eq!(fr.name, "get_weather");
assert_eq!(contents[3].role.as_deref(), Some("model"));
assert!(contents[3].parts[0].text.is_some());
}
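// MODEL_PROFILES is a first-match prefix table, so more-specific prefixes
// must appear before less-specific ones; the tests below enforce that
// ordering and check that lookup_profile picks the most specific row,
// ignores a leading provider prefix, and falls back to DEFAULT_PROFILE.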
#[test]
fn test_profile_table_ordering_invariant() {
for (i, profile_i) in MODEL_PROFILES.iter().enumerate() {
assert!(
!profile_i.prefix.is_empty(),
"profile at index {} has empty prefix",
i
);
for (j, profile_j) in MODEL_PROFILES.iter().enumerate().skip(i + 1) {
let a = profile_i.prefix;
let b = profile_j.prefix;
if b.starts_with(a) {
panic!(
"Profile ordering violation: '{}' at index {} is more specific than \
'{}' at index {} but appears later. Move more-specific prefixes before \
less-specific ones so the first-match lookup returns the right row.",
b, j, a, i
);
}
}
}
}
#[test]
fn test_lookup_profile_most_specific_wins() {
let p = GeminiProvider::lookup_profile("gemini-2.5-flash-lite");
assert_eq!(
p.prefix, "gemini-2.5-flash-lite",
"flash-lite prefix must match first"
);
assert!(!p.thinks_by_default, "Flash-Lite must not think by default");
let p2 = GeminiProvider::lookup_profile("gemini-2.5-flash");
assert_eq!(p2.prefix, "gemini-2.5-flash", "flash prefix must match");
assert!(p2.thinks_by_default, "Flash must think by default");
let p3 = GeminiProvider::lookup_profile("gemini-3.1-pro-preview");
assert_eq!(p3.prefix, "gemini-3.1-", "3.1 prefix must match before 3-");
}
#[test]
fn test_lookup_profile_strips_provider_prefix() {
let bare = GeminiProvider::lookup_profile("gemini-2.5-flash");
let prefixed = GeminiProvider::lookup_profile("vertexai:gemini-2.5-flash");
assert_eq!(
bare.prefix, prefixed.prefix,
"provider prefix should not affect the resolved model profile"
);
}
#[test]
fn test_lookup_profile_unknown_returns_default() {
let p = GeminiProvider::lookup_profile("totally-unknown-model-9000");
assert_eq!(p.prefix, "", "unknown model must return DEFAULT_PROFILE");
assert!(
matches!(p.thinking, ThinkingStyle::None),
"default profile must not claim thinking support"
);
}
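// build_thinking_config keys off the profile's ThinkingStyle rather than
// string-matching the model name: Budget-style models get thinking_budget,
// Level-style models get thinking_level, non-thinking models get nothing,
// and no config at all is emitted unless the caller asked for thinking.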
#[test]
fn test_build_thinking_config_uses_style_not_name() {
use crate::traits::CompletionOptions;
let flash = GeminiProvider::new("k").with_model("gemini-2.5-flash");
let opts = CompletionOptions {
gemini_include_thoughts: Some(true),
..Default::default()
};
let cfg = flash
.build_thinking_config(&opts)
.expect("should produce config");
assert!(
cfg.thinking_budget.is_some(),
"2.5 must use thinking_budget"
);
assert!(
cfg.thinking_level.is_none(),
"2.5 must not set thinking_level"
);
let f3 = GeminiProvider::new("k").with_model("gemini-3-flash");
let cfg3 = f3
.build_thinking_config(&opts)
.expect("should produce config");
assert!(
cfg3.thinking_level.is_some(),
"Gemini 3 must use thinking_level"
);
assert!(
cfg3.thinking_budget.is_none(),
"Gemini 3 must not set thinking_budget"
);
let lite = GeminiProvider::new("k").with_model("gemini-2.5-flash-lite");
let cfg_lite = lite
.build_thinking_config(&opts)
.expect("flash-lite should accept opt-in");
assert!(cfg_lite.thinking_budget.is_some());
}
#[test]
fn test_build_thinking_config_none_for_no_thinking_model() {
use crate::traits::CompletionOptions;
let provider = GeminiProvider::new("k").with_model("gemini-1.5-flash");
let opts = CompletionOptions {
gemini_include_thoughts: Some(true),
gemini_thinking_budget: Some(1024),
..Default::default()
};
assert!(
provider.build_thinking_config(&opts).is_none(),
"1.5 models must not emit a ThinkingConfig even when caller opts in"
);
}
#[test]
fn test_build_thinking_config_none_when_not_requested() {
use crate::traits::CompletionOptions;
let provider = GeminiProvider::new("k").with_model("gemini-2.5-flash");
let cfg = provider.build_thinking_config(&CompletionOptions::default());
assert!(
cfg.is_none(),
"No thinking fields → None so the API uses its own model defaults"
);
}
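// Spot-check two concrete profile rows: Flash-Lite only thinks on request
// within a bounded budget, and the 3.1 preview must be routed through the
// global Vertex endpoint (requires_global_vertex) with level-based thinking.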
#[test]
fn test_flash_lite_profile_correctness() {
let p = GeminiProvider::lookup_profile("gemini-2.5-flash-lite");
assert!(!p.thinks_by_default);
assert!(!p.auto_preview_suffix);
assert!(!p.requires_global_vertex);
assert_eq!(p.context_length, 1_048_576);
assert!(matches!(
p.thinking,
ThinkingStyle::Budget {
min: 512,
max: 24_576
}
));
}
#[test]
fn test_gemini31_profile_correctness() {
let p = GeminiProvider::lookup_profile("gemini-3.1-pro-preview");
assert!(p.auto_preview_suffix || p.prefix == "gemini-3.1-");
assert_eq!(p.prefix, "gemini-3.1-");
assert!(p.requires_global_vertex);
assert!(matches!(p.thinking, ThinkingStyle::Level));
assert!(p.thinks_by_default);
}
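// apply_generation_options copies only the options the caller actually set
// and maps the "json_object" response format to the "application/json" MIME
// type; fields the caller omits keep whatever value was already there.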
#[test]
fn test_apply_generation_options_all_fields() {
use crate::traits::CompletionOptions;
let opts = CompletionOptions {
max_tokens: Some(512),
temperature: Some(0.5),
top_p: Some(0.9),
stop: Some(vec!["END".to_string()]),
response_format: Some("json_object".to_string()),
..Default::default()
};
let mut config = GenerationConfig::default();
GeminiProvider::apply_generation_options(&mut config, &opts);
assert_eq!(config.max_output_tokens, Some(512));
let temp = config.temperature.expect("temperature must be set");
assert!((temp - 0.5_f32).abs() < 0.001, "temperature mismatch");
let top_p = config.top_p.expect("top_p must be set");
assert!((top_p - 0.9_f32).abs() < 0.001, "top_p mismatch");
assert_eq!(config.stop_sequences, Some(vec!["END".to_string()]));
assert_eq!(
config.response_mime_type.as_deref(),
Some("application/json")
);
}
#[test]
fn test_apply_generation_options_missing_fields_not_overwritten() {
use crate::traits::CompletionOptions;
let mut config = GenerationConfig {
max_output_tokens: Some(999),
..Default::default()
};
GeminiProvider::apply_generation_options(&mut config, &CompletionOptions::default());
assert_eq!(
config.max_output_tokens,
Some(999),
"pre-existing value must not be clobbered"
);
assert!(config.temperature.is_none());
}
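// Context-window sizes are resolved through the same profile table, from the
// 2M-token 3.x and 1.5 Pro models down to the 32K-token 1.0 generation.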
#[test]
fn test_context_length_via_profile() {
assert_eq!(
GeminiProvider::context_length_for_model("gemini-3-flash"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-3.1-pro-preview"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.5-flash-lite"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.5-pro"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-2.0-flash"),
1_048_576
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.5-pro"),
2_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.5-flash"),
1_000_000
);
assert_eq!(
GeminiProvider::context_length_for_model("gemini-1.0-pro"),
32_000
);
}
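// A fully blocked prompt returns promptFeedback and no candidates, so the
// response type must tolerate `"candidates": null`.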
#[test]
fn test_prompt_feedback_deserialization() {
let json = r#"{
"candidates": null,
"promptFeedback": {
"blockReason": "SAFETY",
"safetyRatings": []
}
}"#;
let resp: GenerateContentResponse = serde_json::from_str(json).unwrap();
assert!(resp.candidates.is_none());
let pf = resp
.prompt_feedback
.expect("promptFeedback must deserialize");
assert_eq!(pf.block_reason.as_deref(), Some("SAFETY"));
}
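// Embedding dimensions are resolved per model: 3072 for the gemini-embedding
// family, 768 for text-embedding-004.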
#[test]
fn test_gemini_embedding_2_dimension() {
assert_eq!(
GeminiProvider::dimension_for_model("gemini-embedding-2-preview"),
3072
);
assert_eq!(
GeminiProvider::dimension_for_model("gemini-embedding-001"),
3072
);
assert_eq!(
GeminiProvider::dimension_for_model("text-embedding-004"),
768
);
}
}