use crate::client::{
self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
Transport,
};
use crate::http_client;
use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
use serde::Deserialize;
use std::fmt::Debug;
#[cfg(any(feature = "image", feature = "audio"))]
use crate::client::Nothing;
/// Default host for both Gemini client flavours; request paths are appended to it.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
/// Provider extension for the `generateContent`-style Gemini API.
///
/// Holds the API key that is attached to every request URL.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

// `Debug` is implemented by hand (instead of derived) so that formatting the
// extension with `{:?}` never prints the secret API key.
impl std::fmt::Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
/// Marker type selecting the `generateContent` Gemini provider when building a client.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key that is sent with every request.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

// `Debug` is implemented by hand (instead of derived) so that formatting the
// extension with `{:?}` never prints the secret API key.
impl std::fmt::Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
/// Marker type selecting the Gemini Interactions API provider when building a client.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
/// API key used to authenticate against the Gemini API.
///
/// Deliberately has no `Debug` impl, so the key cannot be printed by accident.
pub struct GeminiApiKey(String);

/// Accepts anything string-like (`&str`, `String`, ...) as an API key.
impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(key: S) -> Self {
        let owned: String = key.into();
        GeminiApiKey(owned)
    }
}
// `GeminiApiKey` is the key type consumed by both Gemini client builders.
impl ApiKey for GeminiApiKey {}

/// Gemini client using the `generateContent` API (defaults to a `reqwest` HTTP client).
pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
/// Builder for [`Client`].
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
/// Gemini client using the Interactions API (defaults to a `reqwest` HTTP client).
pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
// Debug fields reported for a Gemini client; the API key is always masked.
impl DebugExt for GeminiExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        let masked_key = ("api_key", (&"******") as &dyn Debug);
        Some(masked_key).into_iter()
    }
}
// Debug fields reported for a Gemini Interactions client; the API key is always masked.
impl DebugExt for GeminiInteractionsExt {
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        let masked_key = ("api_key", (&"******") as &dyn Debug);
        Some(masked_key).into_iter()
    }
}
impl Provider for GeminiExt {
    type Builder = GeminiBuilder;

    const VERIFY_PATH: &'static str = "/v1beta/models";

    /// Builds the full request URL for `path` under `base_url`.
    ///
    /// The API key is appended as a `key=` query parameter. For `Transport::Sse`,
    /// `alt=sse` is appended as well. `&` is used as the separator when `path`
    /// already carries a query string; otherwise `?` is used.
    ///
    /// NOTE(review): this places the secret in the URL, where proxies or logs may
    /// record it; the Interactions provider in this file sends it as the
    /// `x-goog-api-key` header instead.
    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
        // Strip a trailing slash from a user-supplied base URL so the joined URL
        // never contains `//` between host and path.
        let base_url = base_url.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        let separator = if path.contains('?') { '&' } else { '?' };
        match transport {
            Transport::Sse => format!(
                "{base_url}/{path}{separator}alt=sse&key={}",
                self.api_key
            ),
            _ => format!("{base_url}/{path}{separator}key={}", self.api_key),
        }
    }
}
impl Provider for GeminiInteractionsExt {
    type Builder = GeminiInteractionsBuilder;

    const VERIFY_PATH: &'static str = "/v1beta/models";

    /// Builds the full request URL for `path` under `base_url`.
    ///
    /// No credentials are placed in the URL; see `with_custom`. For
    /// `Transport::Sse`, `alt=sse` is appended. `&` is used as the separator when
    /// `path` already carries a query string; otherwise `?` is used.
    fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
        // Strip a trailing slash from a user-supplied base URL so the joined URL
        // never contains `//` between host and path.
        let base_url = base_url.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        match transport {
            Transport::Sse => {
                let separator = if path.contains('?') { '&' } else { '?' };
                format!("{base_url}/{path}{separator}alt=sse")
            }
            _ => format!("{base_url}/{path}"),
        }
    }

    /// Authenticates the request by attaching the API key as the `x-goog-api-key` header.
    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
        Ok(req.header("x-goog-api-key", self.api_key.clone()))
    }
}
// Feature matrix of the `generateContent` client: model listing, completion,
// embeddings and transcription are supported; image/audio generation are not.
impl<H> Capabilities<H> for GeminiExt {
    type ModelListing = Capable<GeminiModelLister<H>>;
    type Completion = Capable<super::completion::CompletionModel>;
    type Embeddings = Capable<super::embedding::EmbeddingModel>;
    type Transcription = Capable<super::transcription::TranscriptionModel>;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
}
// Feature matrix of the Interactions client: completion goes through the
// Interactions API model; embeddings and transcription reuse the shared models;
// image/audio generation are not supported.
impl<H> Capabilities<H> for GeminiInteractionsExt {
    type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
    type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
    type Embeddings = Capable<super::embedding::EmbeddingModel>;
    type Transcription = Capable<super::transcription::TranscriptionModel>;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;
}
// Wires `GeminiBuilder` into the generic client builder: default host and the
// construction of the `GeminiExt` extension from the configured key.
impl ProviderBuilder for GeminiBuilder {
    type ApiKey = GeminiApiKey;

    type Extension<H>
        = GeminiExt
    where
        H: http_client::HttpClientExt;

    const BASE_URL: &'static str = GEMINI_API_BASE_URL;

    /// Copies the builder's API key into a fresh `GeminiExt`.
    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: http_client::HttpClientExt,
    {
        let api_key = builder.get_api_key().0.clone();
        Ok(GeminiExt { api_key })
    }
}
// Wires `GeminiInteractionsBuilder` into the generic client builder: default host
// and the construction of the `GeminiInteractionsExt` extension from the configured key.
impl ProviderBuilder for GeminiInteractionsBuilder {
    type ApiKey = GeminiApiKey;

    type Extension<H>
        = GeminiInteractionsExt
    where
        H: http_client::HttpClientExt;

    const BASE_URL: &'static str = GEMINI_API_BASE_URL;

    /// Copies the builder's API key into a fresh `GeminiInteractionsExt`.
    fn build<H>(
        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: http_client::HttpClientExt,
    {
        let api_key = builder.get_api_key().0.clone();
        Ok(GeminiInteractionsExt { api_key })
    }
}
impl ProviderClient for Client {
    type Input = GeminiApiKey;

    /// Creates a `generateContent` client from the `GEMINI_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `GEMINI_API_KEY` is missing or not valid Unicode, or if the
    /// client cannot be constructed. The trait signature returns `Self`, so
    /// failures cannot be propagated.
    fn from_env() -> Self {
        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
        Self::new(api_key).expect("failed to build Gemini client from GEMINI_API_KEY")
    }

    /// Creates a `generateContent` client from an explicit API key.
    ///
    /// # Panics
    /// Panics if the client cannot be constructed.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).expect("failed to build Gemini client")
    }
}
impl ProviderClient for InteractionsClient {
    type Input = GeminiApiKey;

    /// Creates an Interactions API client from the `GEMINI_API_KEY` environment variable.
    ///
    /// # Panics
    /// Panics if `GEMINI_API_KEY` is missing or not valid Unicode, or if the
    /// client cannot be constructed. The trait signature returns `Self`, so
    /// failures cannot be propagated.
    fn from_env() -> Self {
        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
        Self::new(api_key).expect("failed to build Gemini Interactions client from GEMINI_API_KEY")
    }

    /// Creates an Interactions API client from an explicit API key.
    ///
    /// # Panics
    /// Panics if the client cannot be constructed.
    fn from_val(input: Self::Input) -> Self {
        Self::new(input).expect("failed to build Gemini Interactions client")
    }
}
impl<H> Client<H> {
    /// Converts this `generateContent` client into an [`InteractionsClient`]
    /// carrying the same API key.
    ///
    /// Everything else is carried over by `with_ext` (presumably the HTTP client
    /// and base URL — confirm against `client::Client::with_ext`).
    pub fn interactions_api(self) -> InteractionsClient<H> {
        // Build the new extension first: `with_ext` consumes `self`.
        let interactions_ext = GeminiInteractionsExt {
            api_key: self.ext().api_key.clone(),
        };
        self.with_ext(interactions_ext)
    }
}
impl<H> InteractionsClient<H> {
    /// Converts this Interactions API client back into a `generateContent`
    /// [`Client`] carrying the same API key.
    ///
    /// Everything else is carried over by `with_ext` (presumably the HTTP client
    /// and base URL — confirm against `client::Client::with_ext`).
    pub fn generate_content_api(self) -> Client<H> {
        // Build the new extension first: `with_ext` consumes `self`.
        let content_ext = GeminiExt {
            api_key: self.ext().api_key.clone(),
        };
        self.with_ext(content_ext)
    }
}
/// Error body returned by the API, reduced to its human-readable message.
///
/// NOTE(review): deserialization expects `message` at the top level of the JSON
/// object; confirm this matches the error shape of the endpoints that use it.
#[derive(Deserialize, Debug)]
pub struct ApiErrorResponse {
    /// Human-readable error description.
    pub message: String,
}
/// Response envelope that deserializes either a success payload `T` or an
/// [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order.
/// `Ok(T)` is attempted first, and `Err` is used only when the body does not
/// deserialize as `T`; do not reorder the variants.
/// NOTE(review): if `T` would also accept an error body (e.g. all of its fields
/// are optional), errors will surface as `Ok`; keep `T` strict or confirm at call sites.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Successful response payload.
    Ok(T),
    /// Error payload with a top-level `message`.
    Err(ApiErrorResponse),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: both the direct constructor and the builder accept a placeholder key.
    #[test]
    fn test_client_initialization() {
        let via_new: Result<Client, _> = Client::new("dummy-key");
        let _client: Client = via_new.expect("Client::new() failed");

        let via_builder: Result<Client, _> = Client::builder().api_key("dummy-key").build();
        let _client_from_builder: Client = via_builder.expect("Client::builder() failed");
    }
}