use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
use reqwest::{Client as HttpClient, Proxy};
use tokio::sync::OnceCell;
use crate::auth::OAuthTokenProvider;
use crate::error::{Error, Result};
use google_cloud_auth::credentials::{
Builder as AuthBuilder, CacheableResource, Credentials as GoogleCredentials,
};
use http::Extensions;
use rust_genai_types::http::HttpRetryOptions;
/// Name of the header into which the SDK usage tag is merged
/// (see `append_sdk_usage_header`).
const X_GOOG_API_CLIENT_HEADER: &str = "x-goog-api-client";
/// Usage tag this SDK contributes to `x-goog-api-client`: the crate version
/// (taken from `CARGO_PKG_VERSION` at compile time) followed by a language tag.
const SDK_USAGE_HEADER_VALUE: &str = concat!(
    "google-genai-sdk/",
    env!("CARGO_PKG_VERSION"),
    " gl-rust/unknown"
);
/// Entry point of the SDK; hands out per-service accessors (`models`, `files`, ...).
///
/// Cloning is cheap: every clone shares the same inner state through an `Arc`.
#[derive(Clone)]
pub struct Client {
    inner: Arc<ClientInner>,
}
/// Shared state behind every [`Client`] clone and every service accessor.
pub(crate) struct ClientInner {
    /// HTTP client carrying the configured timeout, proxy and default headers.
    pub http: HttpClient,
    /// Resolved configuration the client was built from.
    pub config: ClientConfig,
    /// Base URL and API version used to build request URLs.
    pub api_client: ApiClient,
    /// Token-based auth source; `None` when authenticating with an API key.
    pub(crate) auth_provider: Option<AuthProvider>,
}
/// Resolved, immutable configuration of a [`Client`].
///
/// `Debug` is implemented by hand so that the API key is never written to logs
/// or panic messages; only its presence is shown.
#[derive(Clone)]
pub struct ClientConfig {
    /// API key for the Gemini API; `None` when OAuth / ADC credentials are used.
    pub api_key: Option<String>,
    /// Backend the client talks to.
    pub backend: Backend,
    /// Vertex AI project/location; `Some` only for [`Backend::VertexAi`].
    pub vertex_config: Option<VertexConfig>,
    /// Transport-level options (timeout, proxy, headers, base URL, retries).
    pub http_options: HttpOptions,
    /// Credentials the client authenticates with.
    pub credentials: Credentials,
    /// OAuth scopes requested for token-based credentials.
    pub auth_scopes: Vec<String>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            // Show only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("backend", &self.backend)
            .field("vertex_config", &self.vertex_config)
            .field("http_options", &self.http_options)
            // Formatted through `Credentials`' own `Debug` implementation.
            .field("credentials", &self.credentials)
            .field("auth_scopes", &self.auth_scopes)
            .finish()
    }
}
/// Which Google API surface the client talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Gemini Developer API (`generativelanguage.googleapis.com`), API key or OAuth.
    GeminiApi,
    /// Vertex AI (`*aiplatform.googleapis.com`); requires project and location.
    VertexAi,
}
/// How the client authenticates.
///
/// `Debug` is implemented by hand so that an API key is never printed.
#[derive(Clone)]
pub enum Credentials {
    /// Static API key (Gemini API only).
    ApiKey(String),
    /// Installed-app OAuth flow driven by a client secret file.
    OAuth {
        /// Path to the OAuth client secret JSON.
        client_secret_path: PathBuf,
        /// Optional location where obtained tokens are cached.
        token_cache_path: Option<PathBuf>,
    },
    /// Google Application Default Credentials.
    ApplicationDefault,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The key is a secret: show the variant but not the value.
            Self::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
            Self::OAuth {
                client_secret_path,
                token_cache_path,
            } => f
                .debug_struct("OAuth")
                .field("client_secret_path", client_secret_path)
                .field("token_cache_path", token_cache_path)
                .finish(),
            Self::ApplicationDefault => f.write_str("ApplicationDefault"),
        }
    }
}
/// Vertex AI target resolved by the builder.
#[derive(Debug, Clone)]
pub struct VertexConfig {
    /// Google Cloud project ID.
    pub project: String,
    /// Vertex AI location; also selects the regional endpoint host.
    pub location: String,
    /// Optional explicit credentials; `ClientBuilder::build` leaves this `None`.
    pub credentials: Option<VertexCredentials>,
}
/// Explicit Vertex AI credentials.
///
/// `Debug` is implemented by hand so that the access token is never printed.
#[derive(Clone)]
pub struct VertexCredentials {
    /// Pre-obtained OAuth access token, if any.
    pub access_token: Option<String>,
}

impl std::fmt::Debug for VertexCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VertexCredentials")
            // Reveal only whether a token is present.
            .field(
                "access_token",
                &self.access_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Transport-level options applied when the client is built.
#[derive(Debug, Clone, Default)]
pub struct HttpOptions {
    /// Whole-request timeout in seconds (passed to reqwest's `timeout`).
    pub timeout: Option<u64>,
    /// Proxy URL applied to all schemes via `Proxy::all`.
    pub proxy: Option<String>,
    /// Extra default headers sent with every request; validated at build time.
    pub headers: HashMap<String, String>,
    /// Overrides the backend's default endpoint; a trailing `/` is ensured.
    pub base_url: Option<String>,
    /// Overrides the default API version (`v1beta` / `v1beta1`).
    pub api_version: Option<String>,
    /// Client-wide retry policy; per-request options take precedence and the
    /// built-in defaults apply when both are absent.
    pub retry_options: Option<HttpRetryOptions>,
}
impl Client {
    /// Creates a Gemini API client that authenticates with `api_key`.
    ///
    /// # Errors
    ///
    /// Fails when the key is not a valid HTTP header value or the underlying HTTP
    /// client cannot be constructed.
    pub fn new(api_key: impl Into<String>) -> Result<Self> {
        Self::builder()
            .backend(Backend::GeminiApi)
            .api_key(api_key)
            .build()
    }

    /// Creates a client configured from environment variables.
    ///
    /// Backend selection:
    /// - `GOOGLE_GENAI_USE_VERTEXAI` (`1`/`true`/`yes`/`on` or `0`/`false`/`no`/`off`)
    ///   decides the backend when it holds a recognised value;
    /// - otherwise Vertex AI is chosen only when both `GOOGLE_CLOUD_PROJECT` and
    ///   `GOOGLE_CLOUD_LOCATION` are set and neither `GEMINI_API_KEY` nor
    ///   `GOOGLE_API_KEY` is.
    ///
    /// The base URL is read from `GOOGLE_GENAI_BASE_URL` or `GENAI_BASE_URL` (and
    /// additionally `GEMINI_BASE_URL` for the Gemini API); the API version from
    /// `GOOGLE_GENAI_API_VERSION` or `GENAI_API_VERSION`. Blank values are ignored.
    ///
    /// # Errors
    ///
    /// Fails when the Gemini API is selected but no API key is set, when Vertex AI
    /// is selected without both project and location, or when building fails.
    pub fn from_env() -> Result<Self> {
        let explicit_vertex = env_flag("GOOGLE_GENAI_USE_VERTEXAI");
        let project = first_nonempty_env(&["GOOGLE_CLOUD_PROJECT"]);
        let location = first_nonempty_env(&["GOOGLE_CLOUD_LOCATION"]);
        let api_key = first_nonempty_env(&["GEMINI_API_KEY", "GOOGLE_API_KEY"]);

        // Without an explicit flag, a complete Vertex environment only wins when
        // no API key is available.
        let use_vertex = explicit_vertex
            .unwrap_or(api_key.is_none() && project.is_some() && location.is_some());

        let builder = if use_vertex {
            let builder = Self::builder().backend(Backend::VertexAi);
            let builder = match project {
                Some(project) => builder.vertex_project(project),
                None => builder,
            };
            match location {
                Some(location) => builder.vertex_location(location),
                None => builder,
            }
        } else {
            let Some(key) = api_key else {
                return Err(Error::InvalidConfig {
                    message: "GEMINI_API_KEY or GOOGLE_API_KEY not found".into(),
                });
            };
            Self::builder().api_key(key).backend(Backend::GeminiApi)
        };

        // `GEMINI_BASE_URL` is Gemini-specific and must not redirect Vertex traffic.
        let base_url = if use_vertex {
            first_nonempty_env(&["GOOGLE_GENAI_BASE_URL", "GENAI_BASE_URL"])
        } else {
            first_nonempty_env(&["GOOGLE_GENAI_BASE_URL", "GENAI_BASE_URL", "GEMINI_BASE_URL"])
        };
        let builder = match base_url {
            Some(url) => builder.base_url(url),
            None => builder,
        };

        let api_version = first_nonempty_env(&["GOOGLE_GENAI_API_VERSION", "GENAI_API_VERSION"]);
        let builder = match api_version {
            Some(version) => builder.api_version(version),
            None => builder,
        };

        builder.build()
    }

    /// Creates a Vertex AI client for `project` / `location` using Application
    /// Default Credentials.
    ///
    /// # Errors
    ///
    /// Fails when the client cannot be built.
    pub fn new_vertex(project: impl Into<String>, location: impl Into<String>) -> Result<Self> {
        Self::builder()
            .vertex_location(location)
            .vertex_project(project)
            .backend(Backend::VertexAi)
            .build()
    }

    /// Creates a Gemini API client that authenticates through the OAuth client
    /// secret stored at `client_secret_path` (no token cache).
    ///
    /// # Errors
    ///
    /// Fails when the OAuth token provider cannot be created from the path.
    pub fn with_oauth(client_secret_path: impl AsRef<Path>) -> Result<Self> {
        let credentials = Credentials::OAuth {
            client_secret_path: client_secret_path.as_ref().to_path_buf(),
            token_cache_path: None,
        };
        Self::builder().credentials(credentials).build()
    }

    /// Creates a client that authenticates with Application Default Credentials.
    ///
    /// # Errors
    ///
    /// Fails when the client cannot be built.
    pub fn with_adc() -> Result<Self> {
        let credentials = Credentials::ApplicationDefault;
        Self::builder().credentials(credentials).build()
    }

    /// Returns an empty [`ClientBuilder`].
    #[must_use]
    pub fn builder() -> ClientBuilder {
        ClientBuilder::default()
    }

    /// Model generation / embedding APIs.
    #[must_use]
    pub fn models(&self) -> crate::models::Models {
        crate::models::Models::new(Arc::clone(&self.inner))
    }

    /// Multi-turn chat sessions.
    #[must_use]
    pub fn chats(&self) -> crate::chats::Chats {
        crate::chats::Chats::new(Arc::clone(&self.inner))
    }

    /// File upload and management APIs.
    #[must_use]
    pub fn files(&self) -> crate::files::Files {
        crate::files::Files::new(Arc::clone(&self.inner))
    }

    /// File search store APIs.
    #[must_use]
    pub fn file_search_stores(&self) -> crate::file_search_stores::FileSearchStores {
        crate::file_search_stores::FileSearchStores::new(Arc::clone(&self.inner))
    }

    /// Document APIs.
    #[must_use]
    pub fn documents(&self) -> crate::documents::Documents {
        crate::documents::Documents::new(Arc::clone(&self.inner))
    }

    /// Live (bidirectional) session APIs.
    #[must_use]
    pub fn live(&self) -> crate::live::Live {
        crate::live::Live::new(Arc::clone(&self.inner))
    }

    /// Live music session APIs.
    #[must_use]
    pub fn live_music(&self) -> crate::live_music::LiveMusic {
        crate::live_music::LiveMusic::new(Arc::clone(&self.inner))
    }

    /// Cached content APIs.
    #[must_use]
    pub fn caches(&self) -> crate::caches::Caches {
        crate::caches::Caches::new(Arc::clone(&self.inner))
    }

    /// Batch job APIs.
    #[must_use]
    pub fn batches(&self) -> crate::batches::Batches {
        crate::batches::Batches::new(Arc::clone(&self.inner))
    }

    /// Tuning job APIs.
    #[must_use]
    pub fn tunings(&self) -> crate::tunings::Tunings {
        crate::tunings::Tunings::new(Arc::clone(&self.inner))
    }

    /// Long-running operation APIs.
    #[must_use]
    pub fn operations(&self) -> crate::operations::Operations {
        crate::operations::Operations::new(Arc::clone(&self.inner))
    }

    /// Auth token APIs.
    #[must_use]
    pub fn auth_tokens(&self) -> crate::tokens::AuthTokens {
        crate::tokens::AuthTokens::new(Arc::clone(&self.inner))
    }

    /// Alias of [`Client::auth_tokens`].
    #[must_use]
    pub fn tokens(&self) -> crate::tokens::Tokens {
        self.auth_tokens()
    }

    /// Interactions APIs.
    #[must_use]
    pub fn interactions(&self) -> crate::interactions::Interactions {
        crate::interactions::Interactions::new(Arc::clone(&self.inner))
    }

    /// Webhook APIs.
    #[must_use]
    pub fn webhooks(&self) -> crate::webhooks::Webhooks {
        crate::webhooks::Webhooks::new(Arc::clone(&self.inner))
    }

    /// Deep research APIs.
    #[must_use]
    pub fn deep_research(&self) -> crate::deep_research::DeepResearch {
        crate::deep_research::DeepResearch::new(Arc::clone(&self.inner))
    }
}
/// Builder for [`Client`]; every field is optional until [`ClientBuilder::build`].
#[derive(Default)]
pub struct ClientBuilder {
    /// API key; mutually exclusive with OAuth / ADC credentials.
    api_key: Option<String>,
    /// Explicit credentials; when absent they are derived from `api_key` / backend.
    credentials: Option<Credentials>,
    /// Explicit backend; when absent it is inferred from the Vertex fields.
    backend: Option<Backend>,
    /// Vertex AI project ID.
    vertex_project: Option<String>,
    /// Vertex AI location.
    vertex_location: Option<String>,
    /// Transport options accumulated by the setters.
    http_options: HttpOptions,
    /// OAuth scopes; defaults depend on the backend when absent.
    auth_scopes: Option<Vec<String>>,
}
impl ClientBuilder {
#[must_use]
pub fn api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = Some(key.into());
self
}
#[must_use]
pub fn credentials(mut self, credentials: Credentials) -> Self {
self.credentials = Some(credentials);
self
}
#[must_use]
pub const fn backend(mut self, backend: Backend) -> Self {
self.backend = Some(backend);
self
}
#[must_use]
pub fn vertex_project(mut self, project: impl Into<String>) -> Self {
self.vertex_project = Some(project.into());
self
}
#[must_use]
pub fn vertex_location(mut self, location: impl Into<String>) -> Self {
self.vertex_location = Some(location.into());
self
}
#[must_use]
pub const fn timeout(mut self, secs: u64) -> Self {
self.http_options.timeout = Some(secs);
self
}
#[must_use]
pub fn proxy(mut self, url: impl Into<String>) -> Self {
self.http_options.proxy = Some(url.into());
self
}
#[must_use]
pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.http_options.headers.insert(key.into(), value.into());
self
}
#[must_use]
pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
self.http_options.base_url = Some(base_url.into());
self
}
#[must_use]
pub fn api_version(mut self, api_version: impl Into<String>) -> Self {
self.http_options.api_version = Some(api_version.into());
self
}
#[must_use]
pub fn retry_options(mut self, retry_options: HttpRetryOptions) -> Self {
self.http_options.retry_options = Some(retry_options);
self
}
#[must_use]
pub fn auth_scopes(mut self, scopes: Vec<String>) -> Self {
self.auth_scopes = Some(scopes);
self
}
pub fn build(self) -> Result<Client> {
let Self {
api_key,
credentials,
backend,
vertex_project,
vertex_location,
http_options,
auth_scopes,
} = self;
let backend = Self::resolve_backend(
backend,
vertex_project.as_deref(),
vertex_location.as_deref(),
);
Self::validate_vertex_config(
backend,
vertex_project.as_deref(),
vertex_location.as_deref(),
)?;
let credentials = Self::resolve_credentials(backend, api_key.as_deref(), credentials)?;
let headers = Self::build_headers(&http_options, backend, &credentials)?;
let http = Self::build_http_client(&http_options, headers)?;
let auth_scopes = auth_scopes.unwrap_or_else(|| default_auth_scopes(backend));
let api_key = match &credentials {
Credentials::ApiKey(key) => Some(key.clone()),
_ => None,
};
let vertex_config = Self::build_vertex_config(backend, vertex_project, vertex_location)?;
let config = ClientConfig {
api_key,
backend,
vertex_config,
http_options,
credentials: credentials.clone(),
auth_scopes,
};
let auth_provider = build_auth_provider(&credentials)?;
let api_client = ApiClient::new(&config);
Ok(Client {
inner: Arc::new(ClientInner {
http,
config,
api_client,
auth_provider,
}),
})
}
fn resolve_backend(
backend: Option<Backend>,
vertex_project: Option<&str>,
vertex_location: Option<&str>,
) -> Backend {
backend.unwrap_or_else(|| {
if vertex_project.is_some() || vertex_location.is_some() {
Backend::VertexAi
} else {
Backend::GeminiApi
}
})
}
fn validate_vertex_config(
backend: Backend,
vertex_project: Option<&str>,
vertex_location: Option<&str>,
) -> Result<()> {
if backend == Backend::VertexAi && (vertex_project.is_none() || vertex_location.is_none()) {
return Err(Error::InvalidConfig {
message: "Project and location required for Vertex AI".into(),
});
}
Ok(())
}
fn resolve_credentials(
backend: Backend,
api_key: Option<&str>,
credentials: Option<Credentials>,
) -> Result<Credentials> {
if credentials.is_some()
&& api_key.is_some()
&& !matches!(credentials, Some(Credentials::ApiKey(_)))
{
return Err(Error::InvalidConfig {
message: "API key cannot be combined with OAuth/ADC credentials".into(),
});
}
let credentials = match credentials {
Some(credentials) => credentials,
None => {
if let Some(api_key) = api_key {
Credentials::ApiKey(api_key.to_string())
} else if backend == Backend::VertexAi {
Credentials::ApplicationDefault
} else {
return Err(Error::InvalidConfig {
message: "API key or OAuth credentials required for Gemini API".into(),
});
}
}
};
if backend == Backend::VertexAi && matches!(credentials, Credentials::ApiKey(_)) {
return Err(Error::InvalidConfig {
message: "Vertex AI does not support API key authentication".into(),
});
}
Ok(credentials)
}
fn build_headers(
http_options: &HttpOptions,
backend: Backend,
credentials: &Credentials,
) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
for (key, value) in &http_options.headers {
let name =
HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
message: format!("Invalid header name: {key}"),
})?;
let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
message: format!("Invalid header value for {key}"),
})?;
headers.insert(name, value);
}
if backend == Backend::GeminiApi {
let api_key = match credentials {
Credentials::ApiKey(key) => key.as_str(),
_ => "",
};
let header_name = HeaderName::from_static("x-goog-api-key");
if !api_key.is_empty() && !headers.contains_key(&header_name) {
let mut header_value =
HeaderValue::from_str(api_key).map_err(|_| Error::InvalidConfig {
message: "Invalid API key value".into(),
})?;
header_value.set_sensitive(true);
headers.insert(header_name, header_value);
}
}
Ok(headers)
}
fn build_http_client(http_options: &HttpOptions, headers: HeaderMap) -> Result<HttpClient> {
let mut http_builder = HttpClient::builder();
if let Some(timeout) = http_options.timeout {
http_builder = http_builder.timeout(Duration::from_secs(timeout));
}
if let Some(proxy_url) = &http_options.proxy {
let proxy = Proxy::all(proxy_url).map_err(|e| Error::InvalidConfig {
message: format!("Invalid proxy: {e}"),
})?;
http_builder = http_builder.proxy(proxy);
}
if !headers.is_empty() {
http_builder = http_builder.default_headers(headers);
}
Ok(http_builder.build()?)
}
fn build_vertex_config(
backend: Backend,
vertex_project: Option<String>,
vertex_location: Option<String>,
) -> Result<Option<VertexConfig>> {
if backend != Backend::VertexAi {
return Ok(None);
}
let project = vertex_project.ok_or_else(|| Error::InvalidConfig {
message: "Project and location required for Vertex AI".into(),
})?;
let location = vertex_location.ok_or_else(|| Error::InvalidConfig {
message: "Project and location required for Vertex AI".into(),
})?;
Ok(Some(VertexConfig {
project,
location,
credentials: None,
}))
}
}
/// Creates the token-based auth provider for `credentials`.
///
/// API keys need no provider (`None`); OAuth loads its provider from the given
/// paths; ADC starts with an empty cell that is filled on first use.
///
/// # Errors
///
/// Propagates failures from `OAuthTokenProvider::from_paths`.
fn build_auth_provider(credentials: &Credentials) -> Result<Option<AuthProvider>> {
    let provider = match credentials {
        Credentials::ApiKey(_) => return Ok(None),
        Credentials::OAuth {
            client_secret_path,
            token_cache_path,
        } => {
            let oauth = OAuthTokenProvider::from_paths(
                client_secret_path.clone(),
                token_cache_path.clone(),
            )?;
            AuthProvider::OAuth(Arc::new(oauth))
        }
        Credentials::ApplicationDefault => {
            AuthProvider::ApplicationDefault(Arc::new(OnceCell::new()))
        }
    };
    Ok(Some(provider))
}
/// Source of per-request authentication headers for token-based credentials.
#[derive(Clone)]
pub(crate) enum AuthProvider {
    /// OAuth token provider built from a client secret file.
    OAuth(Arc<OAuthTokenProvider>),
    /// Application Default Credentials, initialized lazily on first use.
    ApplicationDefault(Arc<OnceCell<Arc<GoogleCredentials>>>),
}
impl AuthProvider {
    /// Produces the authentication headers for an outgoing request.
    ///
    /// * `OAuth` — fetches a token and returns a single sensitive
    ///   `Authorization: Bearer <token>` header.
    /// * `ApplicationDefault` — builds the ADC credentials with `scopes` on the
    ///   first successful call (later calls reuse the cached credentials, so
    ///   `scopes` only matters for that first initialization) and returns the
    ///   headers those credentials produce.
    ///
    /// # Errors
    ///
    /// `Error::Auth` when the token is not a valid header value, when ADC
    /// initialization or header retrieval fails, or when ADC answers
    /// `NotModified`. Errors from `OAuthTokenProvider::token` are propagated.
    async fn headers(&self, scopes: &[&str]) -> Result<HeaderMap> {
        match self {
            Self::OAuth(provider) => {
                let token = provider.token().await?;
                let mut header =
                    HeaderValue::from_str(&format!("Bearer {token}")).map_err(|_| Error::Auth {
                        message: "Invalid OAuth access token".into(),
                    })?;
                // Keep the bearer token out of debug output.
                header.set_sensitive(true);
                let mut headers = HeaderMap::new();
                headers.insert(AUTHORIZATION, header);
                Ok(headers)
            }
            Self::ApplicationDefault(cell) => {
                // Lazily build the credentials once and cache them in the shared cell.
                let credentials = cell
                    .get_or_try_init(|| async {
                        AuthBuilder::default()
                            .with_scopes(scopes.iter().copied())
                            .build()
                            .map(Arc::new)
                            .map_err(|err| Error::Auth {
                                message: format!("ADC init failed: {err}"),
                            })
                    })
                    .await?;
                // An empty `Extensions` is passed, i.e. no previously cached entity
                // tag; presumably the library then always returns fresh headers.
                let headers = credentials
                    .headers(Extensions::new())
                    .await
                    .map_err(|err| Error::Auth {
                        message: format!("ADC header fetch failed: {err}"),
                    })?;
                match headers {
                    CacheableResource::New { data, .. } => Ok(data),
                    // Nothing is cached on this side to fall back on.
                    CacheableResource::NotModified => Err(Error::Auth {
                        message: "ADC header fetch returned NotModified without cached headers"
                            .into(),
                    }),
                }
            }
        }
    }
}
// Built-in retry policy, used when neither the request nor the client supplies
// `HttpRetryOptions`.

/// Maximum number of attempts, the first one included.
const DEFAULT_RETRY_ATTEMPTS: u32 = 5;
/// Backoff before the first retry, in seconds.
const DEFAULT_RETRY_INITIAL_DELAY_SECS: f64 = 1.0;
/// Upper bound for any single retry delay, in seconds.
const DEFAULT_RETRY_MAX_DELAY_SECS: f64 = 60.0;
/// Multiplier applied to the backoff for each subsequent retry.
const DEFAULT_RETRY_EXP_BASE: f64 = 2.0;
/// Upper bound of the clock-derived extra delay added to each backoff, in seconds.
const DEFAULT_RETRY_JITTER: f64 = 1.0;
/// Statuses considered transient: timeout, rate limit, and server errors.
const DEFAULT_RETRY_HTTP_STATUS_CODES: [u16; 6] = [408, 429, 500, 502, 503, 504];
/// The defaults above bundled as `HttpRetryOptions`.
static DEFAULT_HTTP_RETRY_OPTIONS: LazyLock<HttpRetryOptions> =
    LazyLock::new(|| HttpRetryOptions {
        attempts: Some(DEFAULT_RETRY_ATTEMPTS),
        initial_delay: Some(DEFAULT_RETRY_INITIAL_DELAY_SECS),
        max_delay: Some(DEFAULT_RETRY_MAX_DELAY_SECS),
        exp_base: Some(DEFAULT_RETRY_EXP_BASE),
        jitter: Some(DEFAULT_RETRY_JITTER),
        http_status_codes: Some(DEFAULT_RETRY_HTTP_STATUS_CODES.to_vec()),
    });
/// Retry bookkeeping stored in the extensions of a non-success response.
#[derive(Debug, Clone, Copy)]
pub(crate) struct RetryMetadata {
    /// Number of attempts made before this response was returned.
    pub attempts: u32,
    /// Whether the response status is in the configured retryable set.
    pub retryable: bool,
}
impl ClientInner {
pub async fn send(&self, request: reqwest::RequestBuilder) -> Result<reqwest::Response> {
self.send_with_http_options(request, None).await
}
pub async fn send_with_http_options(
&self,
request: reqwest::RequestBuilder,
request_http_options: Option<&rust_genai_types::http::HttpOptions>,
) -> Result<reqwest::Response> {
let retry_options = request_http_options
.and_then(|options| options.retry_options.as_ref())
.or(self.config.http_options.retry_options.as_ref())
.unwrap_or(&DEFAULT_HTTP_RETRY_OPTIONS);
let request_template = request.build()?;
self.execute_with_retry(request_template, retry_options)
.await
}
async fn execute_once(&self, mut request: reqwest::Request) -> Result<reqwest::Response> {
self.prepare_request(&mut request).await?;
Ok(self.http.execute(request).await?)
}
async fn execute_with_retry(
&self,
request_template: reqwest::Request,
retry_options: &HttpRetryOptions,
) -> Result<reqwest::Response> {
let attempts = retry_options.attempts.unwrap_or(DEFAULT_RETRY_ATTEMPTS);
let retryable_codes: &[u16] = retry_options
.http_status_codes
.as_deref()
.unwrap_or(&DEFAULT_RETRY_HTTP_STATUS_CODES);
if attempts <= 1 {
let mut response = self.execute_once(request_template).await?;
if !response.status().is_success() {
attach_retry_metadata_for_codes(&mut response, 1, retryable_codes);
}
return Ok(response);
}
if request_template.try_clone().is_none() {
let mut response = self.execute_once(request_template).await?;
if !response.status().is_success() {
attach_retry_metadata_for_codes(&mut response, 1, retryable_codes);
}
return Ok(response);
}
for attempt in 0..attempts {
let request = request_template
.try_clone()
.expect("request_template is cloneable");
let response = self.execute_once(request).await?;
if response.status().is_success() {
return Ok(response);
}
let status = response.status().as_u16();
let should_retry = retryable_codes.contains(&status);
let is_last_attempt = attempt + 1 >= attempts;
if !should_retry || is_last_attempt {
let mut response = response;
attach_retry_metadata(&mut response, attempt + 1, should_retry);
return Ok(response);
}
let delay = bounded_retry_delay_secs(
retry_options,
attempt,
retry_after_delay_secs(response.headers()),
);
drop(response);
if delay > 0.0 {
tokio::time::sleep(Duration::from_secs_f64(delay)).await;
}
}
unreachable!("retry loop must return a response");
}
async fn prepare_request(&self, request: &mut reqwest::Request) -> Result<()> {
if let Some(headers) = self.auth_headers().await? {
for (name, value) in &headers {
if request.headers().contains_key(name) {
continue;
}
let mut value = value.clone();
if name == AUTHORIZATION {
value.set_sensitive(true);
}
request.headers_mut().insert(name.clone(), value);
}
}
if self.config.backend == Backend::GeminiApi {
append_sdk_usage_header(request.headers_mut())?;
}
#[cfg(feature = "mcp")]
crate::mcp::append_mcp_usage_header(request.headers_mut())?;
Ok(())
}
async fn auth_headers(&self) -> Result<Option<HeaderMap>> {
let Some(provider) = &self.auth_provider else {
return Ok(None);
};
let scopes: Vec<&str> = self.config.auth_scopes.iter().map(String::as_str).collect();
let headers = provider.headers(&scopes).await?;
Ok(Some(headers))
}
}
fn append_sdk_usage_header(headers: &mut HeaderMap) -> Result<()> {
let header_name = HeaderName::from_static(X_GOOG_API_CLIENT_HEADER);
let existing_values = headers
.get_all(&header_name)
.iter()
.map(|value| {
value
.to_str()
.map(str::trim)
.map(str::to_string)
.map_err(|_| Error::InvalidConfig {
message: "Invalid x-goog-api-client header value".into(),
})
})
.collect::<Result<Vec<_>>>()?;
let existing = existing_values
.into_iter()
.filter(|value| !value.is_empty())
.collect::<Vec<_>>()
.join(" ");
let combined = if existing.contains(SDK_USAGE_HEADER_VALUE) {
existing
} else if existing.is_empty() {
SDK_USAGE_HEADER_VALUE.to_string()
} else {
format!("{SDK_USAGE_HEADER_VALUE} {existing}")
};
let value = HeaderValue::from_str(&combined).map_err(|_| Error::InvalidConfig {
message: "Invalid x-goog-api-client header value".into(),
})?;
headers.insert(header_name, value);
Ok(())
}
/// Returns the trimmed value of the first variable in `names` that is set,
/// valid Unicode, and not blank.
fn first_nonempty_env(names: &[&str]) -> Option<String> {
    for name in names {
        if let Ok(raw) = std::env::var(name) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_owned());
            }
        }
    }
    None
}
/// Parses the boolean environment variable `name` case-insensitively.
///
/// Returns `Some(true)` for `1`/`true`/`yes`/`on`, `Some(false)` for
/// `0`/`false`/`no`/`off`, and `None` when unset or unrecognised.
fn env_flag(name: &str) -> Option<bool> {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];
    const FALSY: [&str; 4] = ["0", "false", "no", "off"];
    let normalized = std::env::var(name).ok()?.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}
/// Stores (or replaces) the `RetryMetadata` in the response extensions.
fn attach_retry_metadata(response: &mut reqwest::Response, attempts: u32, retryable: bool) {
    let metadata = RetryMetadata {
        attempts,
        retryable,
    };
    response.extensions_mut().insert(metadata);
}
/// Attaches retry metadata, deriving `retryable` from whether the response
/// status is listed in `retryable_codes`.
fn attach_retry_metadata_for_codes(
    response: &mut reqwest::Response,
    attempts: u32,
    retryable_codes: &[u16],
) {
    let status = response.status().as_u16();
    let retryable = retryable_codes.iter().any(|&code| code == status);
    attach_retry_metadata(response, attempts, retryable);
}
/// Reads the `Retry-After` header as a delay in seconds.
///
/// Accepts a numeric value (negative values clamp to `0`) or an HTTP date
/// (dates in the past yield `0`). Returns `None` when the header is missing,
/// not ASCII, or unparsable.
fn retry_after_delay_secs(headers: &HeaderMap) -> Option<f64> {
    let raw = headers
        .get(reqwest::header::RETRY_AFTER)?
        .to_str()
        .ok()?
        .trim();
    if let Ok(seconds) = raw.parse::<f64>() {
        return Some(seconds.max(0.0));
    }
    let deadline = httpdate::parse_http_date(raw).ok()?;
    let remaining = deadline
        .duration_since(SystemTime::now())
        .unwrap_or_default();
    Some(remaining.as_secs_f64())
}
/// Chooses the delay before retry number `retry_index` (zero-based).
///
/// A server-provided `Retry-After` wins over the computed backoff; either way
/// the result is capped at `max_delay` (a zero cap means no delay at all).
fn bounded_retry_delay_secs(
    options: &HttpRetryOptions,
    retry_index: u32,
    retry_after_secs: Option<f64>,
) -> f64 {
    let requested = match retry_after_secs {
        Some(server_delay) => server_delay,
        None => retry_delay_secs(options, retry_index),
    };
    let cap = options
        .max_delay
        .unwrap_or(DEFAULT_RETRY_MAX_DELAY_SECS)
        .max(0.0);
    requested.min(cap)
}
/// Exponential backoff with jitter for retry number `retry_index` (zero-based).
///
/// `initial * exp_base^retry_index` plus up to `jitter` seconds derived from the
/// current clock's sub-second nanos. Negative options are treated as `0`. A
/// positive `max_delay` caps both the backoff and the final value; a zero
/// `max_delay` disables the cap here.
fn retry_delay_secs(options: &HttpRetryOptions, retry_index: u32) -> f64 {
    let non_negative = |value: Option<f64>, fallback: f64| value.unwrap_or(fallback).max(0.0);
    let initial = non_negative(options.initial_delay, DEFAULT_RETRY_INITIAL_DELAY_SECS);
    let max_delay = non_negative(options.max_delay, DEFAULT_RETRY_MAX_DELAY_SECS);
    let exp_base = non_negative(options.exp_base, DEFAULT_RETRY_EXP_BASE);
    let jitter = non_negative(options.jitter, DEFAULT_RETRY_JITTER);

    let cap = |value: f64| {
        if max_delay > 0.0 {
            value.min(max_delay)
        } else {
            value
        }
    };

    let exponential = if exp_base == 0.0 {
        0.0
    } else {
        initial * exp_base.powf(f64::from(retry_index))
    };

    let jitter_secs = if jitter > 0.0 {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos();
        (f64::from(nanos) / 1_000_000_000.0).clamp(0.0, 1.0) * jitter
    } else {
        0.0
    };

    cap(cap(exponential) + jitter_secs)
}
/// OAuth scopes requested when the builder was given none.
fn default_auth_scopes(backend: Backend) -> Vec<String> {
    let scopes: &[&str] = match backend {
        Backend::VertexAi => &["https://www.googleapis.com/auth/cloud-platform"],
        Backend::GeminiApi => &[
            "https://www.googleapis.com/auth/generative-language",
            "https://www.googleapis.com/auth/generative-language.retriever",
        ],
    };
    scopes.iter().map(|scope| (*scope).to_string()).collect()
}
/// Endpoint information used to build request URLs.
pub(crate) struct ApiClient {
    /// Base URL, always ending in `/`.
    pub base_url: String,
    /// API version path segment (e.g. `v1beta`).
    pub api_version: String,
}
impl ApiClient {
    /// Resolves the base URL and API version for `config`.
    ///
    /// A non-blank `http_options.base_url` is used as-is (trimmed, with a trailing
    /// `/` ensured); otherwise the backend default applies. Likewise a non-blank
    /// `http_options.api_version` overrides the default `v1beta1` (Vertex AI) or
    /// `v1beta` (Gemini API). Blank overrides fall back to the defaults instead of
    /// producing a `/` base URL or an empty version segment.
    pub fn new(config: &ClientConfig) -> Self {
        let base_url = config
            .http_options
            .base_url
            .as_deref()
            .filter(|url| !url.trim().is_empty())
            .map_or_else(|| Self::default_base_url(config), normalize_base_url);
        let api_version = config
            .http_options
            .api_version
            .as_deref()
            .map(str::trim)
            .filter(|version| !version.is_empty())
            .map_or_else(
                || match config.backend {
                    Backend::VertexAi => "v1beta1".to_string(),
                    Backend::GeminiApi => "v1beta".to_string(),
                },
                str::to_string,
            );
        Self {
            base_url,
            api_version,
        }
    }

    /// Default endpoint for the configured backend.
    ///
    /// Vertex AI uses the regional host `https://{location}-aiplatform.googleapis.com/`,
    /// except for the `global` location (or no location), which is served by
    /// `https://aiplatform.googleapis.com/`; there is no `global-aiplatform` host.
    fn default_base_url(config: &ClientConfig) -> String {
        match config.backend {
            Backend::GeminiApi => "https://generativelanguage.googleapis.com/".to_string(),
            Backend::VertexAi => {
                let location = config
                    .vertex_config
                    .as_ref()
                    .map_or("", |cfg| cfg.location.trim());
                if location.is_empty() || location.eq_ignore_ascii_case("global") {
                    "https://aiplatform.googleapis.com/".to_string()
                } else {
                    format!("https://{location}-aiplatform.googleapis.com/")
                }
            }
        }
    }
}
/// Trims surrounding whitespace and guarantees a single trailing `/` is present
/// (an existing trailing slash is kept as-is).
fn normalize_base_url(base_url: &str) -> String {
    let trimmed = base_url.trim();
    if trimmed.ends_with('/') {
        trimmed.to_string()
    } else {
        format!("{trimmed}/")
    }
}
/// Unit tests for client construction, environment-variable handling, retry
/// delay helpers, retry metadata, and SDK usage header merging.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::with_env;
    use bytes::Bytes;
    use futures_util::stream;
    use std::path::PathBuf;
    use std::time::SystemTime;
    use tempfile::tempdir;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // ----- construction -----

    #[test]
    fn test_client_from_api_key() {
        let client = Client::new("test-api-key").unwrap();
        assert_eq!(client.inner.config.backend, Backend::GeminiApi);
    }

    #[test]
    fn test_client_builder() {
        let client = Client::builder()
            .api_key("test-key")
            .timeout(30)
            .build()
            .unwrap();
        assert!(client.inner.config.api_key.is_some());
    }

    #[test]
    fn test_vertex_ai_config() {
        let client = Client::new_vertex("my-project", "us-central1").unwrap();
        assert_eq!(client.inner.config.backend, Backend::VertexAi);
        assert_eq!(
            client.inner.api_client.base_url,
            "https://us-central1-aiplatform.googleapis.com/"
        );
    }

    #[test]
    fn test_base_url_normalization() {
        let client = Client::builder()
            .api_key("test-key")
            .base_url("https://example.com")
            .build()
            .unwrap();
        assert_eq!(client.inner.api_client.base_url, "https://example.com/");
    }

    // ----- Client::from_env -----
    // `with_env` sets (`Some`) or clears (`None`) each variable for the closure.

    #[test]
    fn test_from_env_reads_overrides() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GENAI_BASE_URL", Some("https://env.example.com")),
                ("GENAI_API_VERSION", Some("v99")),
                ("GOOGLE_GENAI_API_VERSION", None),
                ("GOOGLE_API_KEY", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.api_client.base_url, "https://env.example.com/");
                assert_eq!(client.inner.api_client.api_version, "v99");
            },
        );
    }

    #[test]
    fn test_from_env_ignores_gemini_base_url_for_vertex() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", Some("true")),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", Some("us-central1")),
                ("GEMINI_BASE_URL", Some("https://gemini-only.example.com")),
                ("GENAI_BASE_URL", None),
                ("GOOGLE_GENAI_BASE_URL", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::VertexAi);
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://us-central1-aiplatform.googleapis.com/"
                );
            },
        );
    }

    #[test]
    fn test_from_env_ignores_empty_overrides() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GENAI_BASE_URL", Some(" ")),
                ("GENAI_API_VERSION", Some("")),
                ("GOOGLE_GENAI_API_VERSION", None),
                ("GOOGLE_API_KEY", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://generativelanguage.googleapis.com/"
                );
                assert_eq!(client.inner.api_client.api_version, "v1beta");
            },
        );
    }

    #[test]
    fn test_from_env_missing_key_errors() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", None),
                ("GENAI_BASE_URL", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", None),
                ("GOOGLE_CLOUD_PROJECT", None),
                ("GOOGLE_CLOUD_LOCATION", None),
            ],
            || {
                let result = Client::from_env();
                assert!(result.is_err());
            },
        );
    }

    #[test]
    fn test_from_env_google_api_key_fallback() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", Some("google-key")),
                ("GOOGLE_GENAI_USE_VERTEXAI", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.api_key.as_deref(), Some("google-key"));
            },
        );
    }

    #[test]
    fn test_from_env_supports_official_vertex_envs() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", Some("true")),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", Some("us-central1")),
                ("GOOGLE_GENAI_API_VERSION", Some("v1")),
                ("GENAI_API_VERSION", Some("v1beta")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::VertexAi);
                assert!(matches!(
                    client.inner.config.credentials,
                    Credentials::ApplicationDefault
                ));
                assert_eq!(client.inner.config.api_key, None);
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://us-central1-aiplatform.googleapis.com/"
                );
                assert_eq!(client.inner.api_client.api_version, "v1");
            },
        );
    }

    #[test]
    fn test_from_env_uses_complete_vertex_env_without_flag_when_api_key_is_absent() {
        with_env(
            &[
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", None),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", Some("us-central1")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::VertexAi);
                assert_eq!(client.inner.config.api_key, None);
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://us-central1-aiplatform.googleapis.com/"
                );
            },
        );
    }

    #[test]
    fn test_from_env_prefers_gemini_when_api_key_and_complete_vertex_env_exist() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", None),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", Some("us-central1")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::GeminiApi);
                assert_eq!(client.inner.config.api_key.as_deref(), Some("env-key"));
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://generativelanguage.googleapis.com/"
                );
            },
        );
    }

    #[test]
    fn test_from_env_explicit_false_prefers_gemini_even_with_complete_vertex_env() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", Some("false")),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", Some("us-central1")),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::GeminiApi);
                assert_eq!(client.inner.config.api_key.as_deref(), Some("env-key"));
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://generativelanguage.googleapis.com/"
                );
            },
        );
    }

    #[test]
    fn test_from_env_prefers_gemini_when_vertex_env_is_partial() {
        with_env(
            &[
                ("GEMINI_API_KEY", Some("env-key")),
                ("GOOGLE_API_KEY", None),
                ("GOOGLE_GENAI_USE_VERTEXAI", None),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", None),
            ],
            || {
                let client = Client::from_env().unwrap();
                assert_eq!(client.inner.config.backend, Backend::GeminiApi);
                assert_eq!(client.inner.config.api_key.as_deref(), Some("env-key"));
                assert_eq!(
                    client.inner.api_client.base_url,
                    "https://generativelanguage.googleapis.com/"
                );
            },
        );
    }

    #[test]
    fn test_from_env_vertex_requires_project_and_location() {
        with_env(
            &[
                ("GOOGLE_GENAI_USE_VERTEXAI", Some("true")),
                ("GOOGLE_CLOUD_PROJECT", Some("vertex-project")),
                ("GOOGLE_CLOUD_LOCATION", None),
                ("GEMINI_API_KEY", None),
                ("GOOGLE_API_KEY", None),
            ],
            || {
                let result = Client::from_env();
                assert!(matches!(result, Err(Error::InvalidConfig { .. })));
            },
        );
    }

    // ----- retry delay helpers -----

    #[test]
    fn test_bounded_retry_delay_secs_prefers_retry_after_with_cap() {
        let options = HttpRetryOptions {
            max_delay: Some(2.0),
            ..Default::default()
        };
        let delay = bounded_retry_delay_secs(&options, 0, Some(120.0));
        assert_eq!(delay, 2.0);
    }

    #[test]
    fn test_bounded_retry_delay_secs_uses_retry_after_when_below_cap() {
        let options = HttpRetryOptions {
            max_delay: Some(5.0),
            ..Default::default()
        };
        let delay = bounded_retry_delay_secs(&options, 0, Some(1.5));
        assert_eq!(delay, 1.5);
    }

    #[test]
    fn test_bounded_retry_delay_secs_caps_retry_after_at_zero() {
        let options = HttpRetryOptions {
            max_delay: Some(0.0),
            ..Default::default()
        };
        let delay = bounded_retry_delay_secs(&options, 0, Some(120.0));
        assert_eq!(delay, 0.0);
    }

    #[test]
    fn test_bounded_retry_delay_secs_falls_back_to_backoff() {
        // Jitter disabled so the backoff is deterministic: 1.0 * 2^2 = 4.0.
        let options = HttpRetryOptions {
            initial_delay: Some(1.0),
            max_delay: Some(10.0),
            exp_base: Some(2.0),
            jitter: Some(0.0),
            ..Default::default()
        };
        let delay = bounded_retry_delay_secs(&options, 2, None);
        assert_eq!(delay, 4.0);
    }

    #[test]
    fn test_retry_after_delay_secs_parses_http_date() {
        let deadline = SystemTime::now() + Duration::from_secs(120);
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::RETRY_AFTER,
            HeaderValue::from_str(&httpdate::fmt_http_date(deadline)).unwrap(),
        );
        let delay = retry_after_delay_secs(&headers).unwrap();
        // HTTP dates have one-second resolution and time passes during the test.
        assert!((110.0..=120.0).contains(&delay));
    }

    // ----- retry execution -----

    #[tokio::test]
    async fn test_send_with_http_options_preserves_custom_retry_metadata_without_retries() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/retry-once"))
            .respond_with(ResponseTemplate::new(409).set_body_string("conflict"))
            .mount(&server)
            .await;
        let client = Client::new("test-key").unwrap();
        // A streamed body cannot be cloned by `try_clone`, so the request is
        // sent once even though the policy allows two attempts.
        let request = client
            .inner
            .http
            .post(format!("{}/retry-once", server.uri()))
            .body(reqwest::Body::wrap_stream(stream::once(async {
                Ok::<Bytes, std::io::Error>(Bytes::from_static(b"payload"))
            })));
        let http_options = rust_genai_types::http::HttpOptions {
            retry_options: Some(HttpRetryOptions {
                attempts: Some(2),
                http_status_codes: Some(vec![409]),
                initial_delay: Some(0.0),
                max_delay: Some(0.0),
                exp_base: Some(0.0),
                jitter: Some(0.0),
            }),
            ..Default::default()
        };
        let response = client
            .inner
            .send_with_http_options(request, Some(&http_options))
            .await
            .unwrap();
        let retry_metadata = response
            .extensions()
            .get::<RetryMetadata>()
            .copied()
            .unwrap();
        assert_eq!(response.status().as_u16(), 409);
        assert_eq!(retry_metadata.attempts, 1);
        assert!(retry_metadata.retryable);
    }

    // ----- credentials and builder validation -----

    #[test]
    fn test_with_oauth_missing_client_secret_errors() {
        let dir = tempdir().unwrap();
        let secret_path = dir.path().join("missing_client_secret.json");
        let err = Client::with_oauth(&secret_path).err().unwrap();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_with_adc_builds_client() {
        let client = Client::with_adc().unwrap();
        assert!(matches!(
            client.inner.config.credentials,
            Credentials::ApplicationDefault
        ));
    }

    #[test]
    fn test_builder_defaults_to_vertex_when_project_set() {
        let client = Client::builder()
            .vertex_project("proj")
            .vertex_location("loc")
            .build()
            .unwrap();
        assert_eq!(client.inner.config.backend, Backend::VertexAi);
        assert!(matches!(
            client.inner.config.credentials,
            Credentials::ApplicationDefault
        ));
    }

    #[test]
    fn test_valid_proxy_is_accepted() {
        let client = Client::builder()
            .api_key("test-key")
            .proxy("http://127.0.0.1:8888")
            .build();
        assert!(client.is_ok());
    }

    #[test]
    fn test_vertex_requires_project_and_location() {
        let result = Client::builder().backend(Backend::VertexAi).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_api_key_with_oauth_is_invalid() {
        let result = Client::builder()
            .api_key("test-key")
            .credentials(Credentials::OAuth {
                client_secret_path: PathBuf::from("client_secret.json"),
                token_cache_path: None,
            })
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_missing_api_key_for_gemini_errors() {
        let result = Client::builder().backend(Backend::GeminiApi).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_header_name_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .header("bad header", "value")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_header_value_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .header("x-test", "bad\nvalue")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_api_key_value_is_rejected() {
        let err = Client::builder().api_key("bad\nkey").build().err().unwrap();
        assert!(
            matches!(err, Error::InvalidConfig { message } if message.contains("Invalid API key value"))
        );
    }

    #[test]
    fn test_invalid_proxy_is_rejected() {
        let result = Client::builder()
            .api_key("test-key")
            .proxy("not a url")
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn test_vertex_api_key_is_rejected() {
        let result = Client::builder()
            .backend(Backend::VertexAi)
            .vertex_project("proj")
            .vertex_location("loc")
            .credentials(Credentials::ApiKey("key".into()))
            .build();
        assert!(result.is_err());
    }

    // ----- auth scopes -----

    #[test]
    fn test_default_auth_scopes() {
        let gemini = default_auth_scopes(Backend::GeminiApi);
        assert!(gemini.iter().any(|s| s.contains("generative-language")));
        let vertex = default_auth_scopes(Backend::VertexAi);
        assert!(vertex.iter().any(|s| s.contains("cloud-platform")));
    }

    #[test]
    fn test_custom_auth_scopes_override_default() {
        let client = Client::builder()
            .api_key("test-key")
            .auth_scopes(vec!["scope-1".to_string()])
            .build()
            .unwrap();
        assert_eq!(client.inner.config.auth_scopes, vec!["scope-1".to_string()]);
    }

    // ----- SDK usage header -----

    #[test]
    fn test_append_sdk_usage_header() {
        let mut headers = HeaderMap::new();
        append_sdk_usage_header(&mut headers).unwrap();
        assert_eq!(
            headers
                .get(X_GOOG_API_CLIENT_HEADER)
                .and_then(|value| value.to_str().ok()),
            Some(SDK_USAGE_HEADER_VALUE)
        );
    }

    #[test]
    fn test_append_sdk_usage_header_preserves_existing_value() {
        let mut headers = HeaderMap::new();
        headers.insert(
            HeaderName::from_static(X_GOOG_API_CLIENT_HEADER),
            HeaderValue::from_static("custom-client/1.0.0"),
        );
        // Applied twice to check the SDK tag is not duplicated.
        append_sdk_usage_header(&mut headers).unwrap();
        append_sdk_usage_header(&mut headers).unwrap();
        assert_eq!(
            headers
                .get(X_GOOG_API_CLIENT_HEADER)
                .and_then(|value| value.to_str().ok()),
            Some(concat!(
                "google-genai-sdk/",
                env!("CARGO_PKG_VERSION"),
                " gl-rust/unknown custom-client/1.0.0"
            ))
        );
    }
}