// rig/providers/gemini/client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6    VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::http_client::{self, HttpClientExt};
9use crate::wasm_compat::*;
10use crate::{
11    Embed,
12    embeddings::{self},
13};
14use bytes::Bytes;
15use serde::Deserialize;
16use std::fmt::Debug;
17
18// ================================================================
19// Google Gemini Client
20// ================================================================
/// Default root URL of the Google Gemini REST API.
///
/// Builders start from this value; it can be overridden with `ClientBuilder::base_url`.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
22
23pub struct ClientBuilder<'a, T = reqwest::Client> {
24    api_key: &'a str,
25    base_url: &'a str,
26    http_client: T,
27}
28
29impl<'a, T> ClientBuilder<'a, T>
30where
31    T: HttpClientExt,
32{
33    pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> {
34        ClientBuilder {
35            api_key,
36            base_url: GEMINI_API_BASE_URL,
37            http_client: Default::default(),
38        }
39    }
40
41    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
42        Self {
43            api_key,
44            base_url: GEMINI_API_BASE_URL,
45            http_client,
46        }
47    }
48
49    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U>
50    where
51        U: HttpClientExt,
52    {
53        ClientBuilder {
54            api_key: self.api_key,
55            base_url: self.base_url,
56            http_client,
57        }
58    }
59
60    pub fn base_url(mut self, base_url: &'a str) -> Self {
61        self.base_url = base_url;
62        self
63    }
64
65    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
66        let mut default_headers = reqwest::header::HeaderMap::new();
67        default_headers.insert(
68            reqwest::header::CONTENT_TYPE,
69            "application/json".parse().unwrap(),
70        );
71
72        Ok(Client {
73            base_url: self.base_url.to_string(),
74            api_key: self.api_key.to_string(),
75            default_headers,
76            http_client: self.http_client,
77        })
78    }
79}
80#[derive(Clone)]
81pub struct Client<T = reqwest::Client> {
82    base_url: String,
83    api_key: String,
84    default_headers: reqwest::header::HeaderMap,
85    http_client: T,
86}
87
88impl<T> Debug for Client<T>
89where
90    T: Debug,
91{
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("Client")
94            .field("base_url", &self.base_url)
95            .field("http_client", &self.http_client)
96            .field("default_headers", &self.default_headers)
97            .field("api_key", &"<REDACTED>")
98            .finish()
99    }
100}
101
102impl<T> Client<T>
103where
104    T: HttpClientExt + Default,
105{
106    /// Create a new Google Gemini client builder.
107    ///
108    /// # Example
109    /// ```
110    /// use rig::providers::gemini::{ClientBuilder, self};
111    ///
112    /// // Initialize the Google Gemini client
113    /// let gemini_client = Client::builder("your-google-gemini-api-key")
114    ///    .build()
115    /// ```
116    pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
117        ClientBuilder::new_with_client(api_key, Default::default())
118    }
119
120    /// Create a new Google Gemini client. For more control, use the `builder` method.
121    ///
122    /// # Panics
123    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
124    pub fn new(api_key: &str) -> Self {
125        Self::builder(api_key)
126            .build()
127            .expect("Gemini client should build")
128    }
129}
130
131impl Client<reqwest::Client> {
132    pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
133        let url = format!(
134            "{}/{}?alt=sse&key={}",
135            self.base_url,
136            path.trim_start_matches('/'),
137            self.api_key
138        );
139
140        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
141
142        self.http_client
143            .post(url)
144            .headers(self.default_headers.clone())
145    }
146}
147
148impl<T> Client<T>
149where
150    T: HttpClientExt,
151{
152    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
153        // API key gets inserted as query param - no need to add bearer auth or headers
154        let url = format!(
155            "{}/{}?key={}",
156            self.base_url,
157            path.trim_start_matches('/'),
158            self.api_key
159        );
160
161        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
162        let mut req = http_client::Request::post(url);
163
164        if let Some(hs) = req.headers_mut() {
165            *hs = self.default_headers.clone();
166        }
167
168        req
169    }
170
171    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
172        // API key gets inserted as query param - no need to add bearer auth or headers
173        let url = format!(
174            "{}/{}?key={}",
175            self.base_url,
176            path.trim_start_matches('/'),
177            self.api_key
178        );
179
180        tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
181
182        let mut req = http_client::Request::get(url);
183
184        if let Some(hs) = req.headers_mut() {
185            *hs = self.default_headers.clone();
186        }
187
188        req
189    }
190
191    pub(crate) async fn send<U, R>(
192        &self,
193        req: http_client::Request<U>,
194    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
195    where
196        U: Into<Bytes> + Send,
197        R: From<Bytes> + Send + 'static,
198    {
199        self.http_client.send(req).await
200    }
201}
202
203// NOTE: (@FayCarsons) This cannot be implemented for all T because `AsCompletion`/`CompletionModel` requires SSE
204// which we are not able to implement for any `T: HttpClientExt` right now
205impl ProviderClient for Client<reqwest::Client> {
206    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
207    /// Panics if the environment variable is not set.
208    fn from_env() -> Self {
209        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
210        Self::new(&api_key)
211    }
212
213    fn from_val(input: crate::client::ProviderValue) -> Self {
214        let crate::client::ProviderValue::Simple(api_key) = input else {
215            panic!("Incorrect provider value type")
216        };
217        Self::new(&api_key)
218    }
219}
220
221impl CompletionClient for Client<reqwest::Client> {
222    type CompletionModel = CompletionModel<reqwest::Client>;
223
224    /// Create a completion model with the given name.
225    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
226    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
227    fn completion_model(&self, model: &str) -> Self::CompletionModel {
228        CompletionModel::new(self.clone(), model)
229    }
230}
231
232impl<T> EmbeddingsClient for Client<T>
233where
234    T: HttpClientExt + Clone + Debug + Default + 'static,
235    Client<T>: CompletionClient,
236{
237    type EmbeddingModel = EmbeddingModel<T>;
238
239    /// Create an embedding model with the given name.
240    /// Note: default embedding dimension of 0 will be used if model is not known.
241    /// If this is the case, it's better to use function `embedding_model_with_ndims`
242    ///
243    /// # Example
244    /// ```
245    /// use rig::providers::gemini::{Client, self};
246    ///
247    /// // Initialize the Google Gemini client
248    /// let gemini = Client::new("your-google-gemini-api-key");
249    ///
250    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
251    /// ```
252    fn embedding_model(&self, model: &str) -> EmbeddingModel<T> {
253        EmbeddingModel::new(self.clone(), model, None)
254    }
255
256    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
257    ///
258    /// # Example
259    /// ```
260    /// use rig::providers::gemini::{Client, self};
261    ///
262    /// // Initialize the Google Gemini client
263    /// let gemini = Client::new("your-google-gemini-api-key");
264    ///
265    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
266    /// ```
267    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel<T> {
268        EmbeddingModel::new(self.clone(), model, Some(ndims))
269    }
270
271    /// Create an embedding builder with the given embedding model.
272    ///
273    /// # Example
274    /// ```
275    /// use rig::providers::gemini::{Client, self};
276    ///
277    /// // Initialize the Google Gemini client
278    /// let gemini = Client::new("your-google-gemini-api-key");
279    ///
280    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
281    ///     .simple_document("doc0", "Hello, world!")
282    ///     .simple_document("doc1", "Goodbye, world!")
283    ///     .build()
284    ///     .await
285    ///     .expect("Failed to embed documents");
286    /// ```
287    fn embeddings<D: Embed>(
288        &self,
289        model: &str,
290    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel<T>, D> {
291        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
292    }
293}
294
295impl<T> TranscriptionClient for Client<T>
296where
297    T: HttpClientExt + Clone + Debug + Default + 'static,
298    Client<T>: CompletionClient,
299{
300    type TranscriptionModel = TranscriptionModel<T>;
301
302    /// Create a transcription model with the given name.
303    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
304    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
305    fn transcription_model(&self, model: &str) -> TranscriptionModel<T> {
306        TranscriptionModel::new(self.clone(), model)
307    }
308}
309
310impl<T> VerifyClient for Client<T>
311where
312    T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
313    Client<T>: CompletionClient,
314{
315    #[cfg_attr(feature = "worker", worker::send)]
316    async fn verify(&self) -> Result<(), VerifyError> {
317        let req = self
318            .get("/v1beta/models")
319            .body(http_client::NoBody)
320            .map_err(|e| VerifyError::HttpError(e.into()))?;
321        let response = self.http_client.send::<_, Vec<u8>>(req).await?;
322
323        match response.status() {
324            reqwest::StatusCode::OK => Ok(()),
325            reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
326            reqwest::StatusCode::INTERNAL_SERVER_ERROR
327            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
328                let text = http_client::text(response).await?;
329                Err(VerifyError::ProviderError(text))
330            }
331            _ => {
332                // TODO: Find/write some alternative for this that uses `http::StatusCode` vs
333                // reqwest::StatusCode
334                //
335                // response.error_for_status()?;
336                Ok(())
337            }
338        }
339    }
340}
341
342impl_conversion_traits!(
343    AsImageGeneration,
344    AsAudioGeneration for Client<T>
345);
346
347#[derive(Debug, Deserialize)]
348pub struct ApiErrorResponse {
349    pub message: String,
350}
351
352#[derive(Debug, Deserialize)]
353#[serde(untagged)]
354pub enum ApiResponse<T> {
355    Ok(T),
356    Err(ApiErrorResponse),
357}