pub mod client;
pub mod errors;
pub mod pricing;
pub mod stream;
pub mod translate;
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use tracing::instrument;
use tt_shared::{
filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
Provider, ProviderError, RequestContext,
};
pub use client::ClientConfig;
/// Gemini API origin used when the request credentials carry no `base_url` override.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";

/// [`Provider`] implementation that talks to Google's Gemini `generateContent` API.
pub struct GeminiProvider {
    /// Forwarded as the `allow_local` flag of `validate_provider_url` whenever the
    /// request supplies its own base URL.
    allow_local: bool,
    /// HTTP client used for every outgoing request from this provider.
    client: Client,
}
impl GeminiProvider {
    /// Creates a provider whose caller-supplied base URLs are validated with
    /// `allow_local = false`.
    ///
    /// # Panics
    ///
    /// Panics if `client::build_client` fails for `cfg`.
    pub fn new(cfg: ClientConfig) -> Self {
        Self::from_config(cfg, false)
    }

    /// Same as [`GeminiProvider::new`], but caller-supplied base URLs are validated
    /// with `allow_local = true` (presumably for tests against local servers — confirm).
    ///
    /// # Panics
    ///
    /// Panics if `client::build_client` fails for `cfg`.
    #[doc(hidden)]
    pub fn new_allow_local(cfg: ClientConfig) -> Self {
        Self::from_config(cfg, true)
    }

    /// Shared constructor: builds the HTTP client from `cfg` and stores the local-URL policy.
    fn from_config(cfg: ClientConfig, allow_local: bool) -> Self {
        let built = client::build_client(&cfg);
        let client = built.expect("failed to build reqwest::Client for Gemini adapter");
        Self {
            client,
            allow_local,
        }
    }

    /// Returns the base URL from the request credentials, or [`DEFAULT_BASE_URL`]
    /// when none is configured.
    fn base_url<'a>(&self, ctx: &'a RequestContext) -> &'a str {
        let configured = ctx.credentials.base_url.as_deref();
        configured.unwrap_or(DEFAULT_BASE_URL)
    }
}
#[async_trait]
impl Provider for GeminiProvider {
fn id(&self) -> &'static str {
"gemini"
}
fn models(&self) -> Vec<ModelInfo> {
pricing::all_models()
}
fn pricing(&self, model: &str) -> Option<ModelPricing> {
pricing::pricing_for(model)
}
fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
let mut out = Vec::new();
if req.n.is_some() {
out.push("n".to_string());
}
if req.seed.is_some() {
out.push("seed".to_string());
}
if req.presence_penalty.is_some() {
out.push("presence_penalty".to_string());
}
if req.frequency_penalty.is_some() {
out.push("frequency_penalty".to_string());
}
if req.user.is_some() {
out.push("user".to_string());
}
out
}
#[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
async fn chat_completion(
&self,
req: ChatCompletionRequest,
ctx: &RequestContext,
) -> Result<ChatCompletionResponse, ProviderError> {
let base_url = self.base_url(ctx);
if ctx.credentials.base_url.is_some() {
validate_provider_url(base_url, self.allow_local)
.map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
}
let api_key = ctx.credentials.api_key.expose().to_string();
let model = req.model.clone();
translate::validate_model_id(&model)?;
let url = format!("{base_url}/v1beta/models/{model}:generateContent");
let body = translate::translate_request(req)?;
let mut request_builder = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header("x-goog-api-key", &api_key)
.json(&body);
for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
request_builder = request_builder.header(name, value);
}
let response = request_builder
.send()
.await
.map_err(errors::map_reqwest_error)?;
let status = response.status().as_u16();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let response_text = response.text().await.map_err(errors::map_reqwest_error)?;
if status >= 400 {
return Err(errors::map_response_error(
status,
&response_text,
retry_after.as_deref(),
&model,
));
}
translate::deserialize_response(&response_text, &model)
}
#[instrument(skip(self, ctx), fields(provider = "gemini", model = %req.model))]
async fn chat_completion_stream(
&self,
req: ChatCompletionRequest,
ctx: &RequestContext,
) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
let base_url = self.base_url(ctx);
if ctx.credentials.base_url.is_some() {
validate_provider_url(base_url, self.allow_local)
.map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
}
let base_url = base_url.to_string();
let client = self.client.clone();
stream::stream_chat_completion(client, &base_url, req, ctx).await
}
async fn embeddings(
&self,
_req: EmbeddingsRequest,
_ctx: &RequestContext,
) -> Result<EmbeddingsResponse, ProviderError> {
Err(ProviderError::Unsupported(
"Gemini embedding models use a separate endpoint; use a dedicated embedding adapter"
.to_string(),
))
}
}