rig/providers/gemini/client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6    VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::{
9    Embed,
10    embeddings::{self},
11};
12use serde::Deserialize;
13
// ================================================================
// Google Gemini Client
// ================================================================

/// Root URL of the Google Generative Language (Gemini) API.
///
/// Used by `ClientBuilder::new` unless overridden with `ClientBuilder::base_url`.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
18
19pub struct ClientBuilder<'a> {
20    api_key: &'a str,
21    base_url: &'a str,
22    http_client: Option<reqwest::Client>,
23}
24
25impl<'a> ClientBuilder<'a> {
26    pub fn new(api_key: &'a str) -> Self {
27        Self {
28            api_key,
29            base_url: GEMINI_API_BASE_URL,
30            http_client: None,
31        }
32    }
33
34    pub fn base_url(mut self, base_url: &'a str) -> Self {
35        self.base_url = base_url;
36        self
37    }
38
39    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
40        self.http_client = Some(client);
41        self
42    }
43
44    pub fn build(self) -> Result<Client, ClientBuilderError> {
45        let mut default_headers = reqwest::header::HeaderMap::new();
46        default_headers.insert(
47            reqwest::header::CONTENT_TYPE,
48            "application/json".parse().unwrap(),
49        );
50        let http_client = if let Some(http_client) = self.http_client {
51            http_client
52        } else {
53            reqwest::Client::builder().build()?
54        };
55
56        Ok(Client {
57            base_url: self.base_url.to_string(),
58            api_key: self.api_key.to_string(),
59            default_headers,
60            http_client,
61        })
62    }
63}
64#[derive(Clone)]
65pub struct Client {
66    base_url: String,
67    api_key: String,
68    default_headers: reqwest::header::HeaderMap,
69    http_client: reqwest::Client,
70}
71
72impl std::fmt::Debug for Client {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("Client")
75            .field("base_url", &self.base_url)
76            .field("http_client", &self.http_client)
77            .field("default_headers", &self.default_headers)
78            .field("api_key", &"<REDACTED>")
79            .finish()
80    }
81}
82
83impl Client {
84    /// Create a new Google Gemini client builder.
85    ///
86    /// # Example
87    /// ```
88    /// use rig::providers::gemini::{ClientBuilder, self};
89    ///
90    /// // Initialize the Google Gemini client
91    /// let gemini_client = Client::builder("your-google-gemini-api-key")
92    ///    .build()
93    /// ```
94    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
95        ClientBuilder::new(api_key)
96    }
97
98    /// Create a new Google Gemini client. For more control, use the `builder` method.
99    ///
100    /// # Panics
101    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
102    pub fn new(api_key: &str) -> Self {
103        Self::builder(api_key)
104            .build()
105            .expect("Gemini client should build")
106    }
107
108    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
109        // API key gets inserted as query param - no need to add bearer auth or headers
110        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
111
112        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
113        self.http_client
114            .post(url)
115            .headers(self.default_headers.clone())
116    }
117
118    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
119        // API key gets inserted as query param - no need to add bearer auth or headers
120        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
121
122        tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
123        self.http_client
124            .get(url)
125            .headers(self.default_headers.clone())
126    }
127
128    pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
129        let url =
130            format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
131
132        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
133        self.http_client
134            .post(url)
135            .headers(self.default_headers.clone())
136    }
137}
138
139impl ProviderClient for Client {
140    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
141    /// Panics if the environment variable is not set.
142    fn from_env() -> Self {
143        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
144        Self::new(&api_key)
145    }
146
147    fn from_val(input: crate::client::ProviderValue) -> Self {
148        let crate::client::ProviderValue::Simple(api_key) = input else {
149            panic!("Incorrect provider value type")
150        };
151        Self::new(&api_key)
152    }
153}
154
155impl CompletionClient for Client {
156    type CompletionModel = CompletionModel;
157
158    /// Create a completion model with the given name.
159    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
160    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
161    fn completion_model(&self, model: &str) -> CompletionModel {
162        CompletionModel::new(self.clone(), model)
163    }
164}
165
166impl EmbeddingsClient for Client {
167    type EmbeddingModel = EmbeddingModel;
168
169    /// Create an embedding model with the given name.
170    /// Note: default embedding dimension of 0 will be used if model is not known.
171    /// If this is the case, it's better to use function `embedding_model_with_ndims`
172    ///
173    /// # Example
174    /// ```
175    /// use rig::providers::gemini::{Client, self};
176    ///
177    /// // Initialize the Google Gemini client
178    /// let gemini = Client::new("your-google-gemini-api-key");
179    ///
180    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
181    /// ```
182    fn embedding_model(&self, model: &str) -> EmbeddingModel {
183        EmbeddingModel::new(self.clone(), model, None)
184    }
185
186    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
187    ///
188    /// # Example
189    /// ```
190    /// use rig::providers::gemini::{Client, self};
191    ///
192    /// // Initialize the Google Gemini client
193    /// let gemini = Client::new("your-google-gemini-api-key");
194    ///
195    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
196    /// ```
197    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
198        EmbeddingModel::new(self.clone(), model, Some(ndims))
199    }
200
201    /// Create an embedding builder with the given embedding model.
202    ///
203    /// # Example
204    /// ```
205    /// use rig::providers::gemini::{Client, self};
206    ///
207    /// // Initialize the Google Gemini client
208    /// let gemini = Client::new("your-google-gemini-api-key");
209    ///
210    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
211    ///     .simple_document("doc0", "Hello, world!")
212    ///     .simple_document("doc1", "Goodbye, world!")
213    ///     .build()
214    ///     .await
215    ///     .expect("Failed to embed documents");
216    /// ```
217    fn embeddings<D: Embed>(
218        &self,
219        model: &str,
220    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
221        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
222    }
223}
224
225impl TranscriptionClient for Client {
226    type TranscriptionModel = TranscriptionModel;
227
228    /// Create a transcription model with the given name.
229    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
230    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
231    fn transcription_model(&self, model: &str) -> TranscriptionModel {
232        TranscriptionModel::new(self.clone(), model)
233    }
234}
235
236impl VerifyClient for Client {
237    #[cfg_attr(feature = "worker", worker::send)]
238    async fn verify(&self) -> Result<(), VerifyError> {
239        let response = self.get("/v1beta/models").send().await?;
240        match response.status() {
241            reqwest::StatusCode::OK => Ok(()),
242            reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
243            reqwest::StatusCode::INTERNAL_SERVER_ERROR
244            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
245                Err(VerifyError::ProviderError(response.text().await?))
246            }
247            _ => {
248                response.error_for_status()?;
249                Ok(())
250            }
251        }
252    }
253}
254
255impl_conversion_traits!(
256    AsImageGeneration,
257    AsAudioGeneration for Client
258);
259
260#[derive(Debug, Deserialize)]
261pub struct ApiErrorResponse {
262    pub message: String,
263}
264
265#[derive(Debug, Deserialize)]
266#[serde(untagged)]
267pub enum ApiResponse<T> {
268    Ok(T),
269    Err(ApiErrorResponse),
270}