rig/providers/gemini/
client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6    VerifyClient, VerifyError, impl_conversion_traits,
7};
8use crate::http_client::{self, HttpClientExt};
9use crate::wasm_compat::*;
10use crate::{
11    Embed,
12    embeddings::{self},
13};
14use bytes::Bytes;
15use serde::Deserialize;
16use std::fmt::Debug;
17
// ----------------------------------------------------------------
// Google Gemini API client
// ----------------------------------------------------------------

/// Default base URL of the Google Generative Language (Gemini) API.
///
/// Request paths such as `/v1beta/models` are appended to this value; it intentionally
/// carries no trailing slash.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
22
23pub struct ClientBuilder<'a, T = reqwest::Client> {
24    api_key: &'a str,
25    base_url: &'a str,
26    http_client: T,
27}
28
29impl<'a, T> ClientBuilder<'a, T>
30where
31    T: HttpClientExt + Default,
32{
33    pub fn new(api_key: &'a str) -> ClientBuilder<'a, T> {
34        ClientBuilder {
35            api_key,
36            base_url: GEMINI_API_BASE_URL,
37            http_client: Default::default(),
38        }
39    }
40
41    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
42        Self {
43            api_key,
44            base_url: GEMINI_API_BASE_URL,
45            http_client,
46        }
47    }
48
49    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U>
50    where
51        U: HttpClientExt,
52    {
53        ClientBuilder {
54            api_key: self.api_key,
55            base_url: self.base_url,
56            http_client,
57        }
58    }
59
60    pub fn base_url(mut self, base_url: &'a str) -> Self {
61        self.base_url = base_url;
62        self
63    }
64
65    pub fn build(self) -> Result<Client<T>, ClientBuilderError> {
66        let mut default_headers = reqwest::header::HeaderMap::new();
67        default_headers.insert(
68            reqwest::header::CONTENT_TYPE,
69            "application/json".parse().unwrap(),
70        );
71
72        Ok(Client {
73            base_url: self.base_url.to_string(),
74            api_key: self.api_key.to_string(),
75            default_headers,
76            http_client: self.http_client,
77        })
78    }
79}
80#[derive(Clone)]
81pub struct Client<T = reqwest::Client> {
82    base_url: String,
83    api_key: String,
84    default_headers: reqwest::header::HeaderMap,
85    pub http_client: T,
86}
87
88impl<T> Debug for Client<T>
89where
90    T: Debug,
91{
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("Client")
94            .field("base_url", &self.base_url)
95            .field("http_client", &self.http_client)
96            .field("default_headers", &self.default_headers)
97            .field("api_key", &"<REDACTED>")
98            .finish()
99    }
100}
101
102impl<T> Client<T>
103where
104    T: HttpClientExt,
105{
106    pub(crate) fn post(&self, path: &str) -> http_client::Builder {
107        // API key gets inserted as query param - no need to add bearer auth or headers
108        let url = format!(
109            "{}/{}?key={}",
110            self.base_url,
111            path.trim_start_matches('/'),
112            self.api_key
113        );
114
115        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
116        let mut req = http_client::Request::post(url);
117
118        if let Some(hs) = req.headers_mut() {
119            *hs = self.default_headers.clone();
120        }
121
122        req
123    }
124
125    pub(crate) fn post_sse(&self, path: &str) -> http_client::Builder {
126        let url = format!(
127            "{}/{}?alt=sse&key={}",
128            self.base_url,
129            path.trim_start_matches('/'),
130            self.api_key
131        );
132
133        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
134
135        let mut req = http_client::Request::post(url);
136
137        if let Some(hs) = req.headers_mut() {
138            *hs = self.default_headers.clone();
139        }
140
141        req
142    }
143
144    pub(crate) fn get(&self, path: &str) -> http_client::Builder {
145        // API key gets inserted as query param - no need to add bearer auth or headers
146        let url = format!(
147            "{}/{}?key={}",
148            self.base_url,
149            path.trim_start_matches('/'),
150            self.api_key
151        );
152
153        tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****");
154
155        let mut req = http_client::Request::get(url);
156
157        if let Some(hs) = req.headers_mut() {
158            *hs = self.default_headers.clone();
159        }
160
161        req
162    }
163
164    pub(crate) async fn send<U, R>(
165        &self,
166        req: http_client::Request<U>,
167    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
168    where
169        U: Into<Bytes> + Send,
170        R: From<Bytes> + Send + 'static,
171    {
172        self.http_client.send(req).await
173    }
174}
175
176impl Client<reqwest::Client> {
177    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
178        ClientBuilder::<reqwest::Client>::new(api_key)
179    }
180
181    /// Create a new Gemini client. For more control, use the `builder` method.
182    ///
183    /// # Panics
184    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
185    pub fn new(api_key: &str) -> Self {
186        ClientBuilder::<reqwest::Client>::new(api_key)
187            .build()
188            .unwrap()
189    }
190
191    pub fn from_env() -> Self {
192        <Self as ProviderClient>::from_env()
193    }
194}
195
196impl<T> ProviderClient for Client<T>
197where
198    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
199{
200    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
201    /// Panics if the environment variable is not set.
202    fn from_env() -> Self {
203        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
204        ClientBuilder::<T>::new(&api_key).build().unwrap()
205    }
206
207    fn from_val(input: crate::client::ProviderValue) -> Self {
208        let crate::client::ProviderValue::Simple(api_key) = input else {
209            panic!("Incorrect provider value type")
210        };
211        ClientBuilder::<T>::new(&api_key).build().unwrap()
212    }
213}
214
215impl<T> CompletionClient for Client<T>
216where
217    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
218{
219    type CompletionModel = CompletionModel<T>;
220
221    /// Create a completion model with the given name.
222    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
223    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
224    fn completion_model(&self, model: &str) -> Self::CompletionModel {
225        CompletionModel::new(self.clone(), model)
226    }
227}
228
229impl<T> EmbeddingsClient for Client<T>
230where
231    T: HttpClientExt + Clone + Debug + Default + 'static,
232    Client<T>: CompletionClient,
233{
234    type EmbeddingModel = EmbeddingModel<T>;
235
236    /// Create an embedding model with the given name.
237    /// Note: default embedding dimension of 0 will be used if model is not known.
238    /// If this is the case, it's better to use function `embedding_model_with_ndims`
239    ///
240    /// # Example
241    /// ```
242    /// use rig::providers::gemini::{Client, self};
243    ///
244    /// // Initialize the Google Gemini client
245    /// let gemini = Client::new("your-google-gemini-api-key");
246    ///
247    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
248    /// ```
249    fn embedding_model(&self, model: &str) -> EmbeddingModel<T> {
250        EmbeddingModel::new(self.clone(), model, None)
251    }
252
253    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
254    ///
255    /// # Example
256    /// ```
257    /// use rig::providers::gemini::{Client, self};
258    ///
259    /// // Initialize the Google Gemini client
260    /// let gemini = Client::new("your-google-gemini-api-key");
261    ///
262    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
263    /// ```
264    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel<T> {
265        EmbeddingModel::new(self.clone(), model, Some(ndims))
266    }
267
268    /// Create an embedding builder with the given embedding model.
269    ///
270    /// # Example
271    /// ```
272    /// use rig::providers::gemini::{Client, self};
273    ///
274    /// // Initialize the Google Gemini client
275    /// let gemini = Client::new("your-google-gemini-api-key");
276    ///
277    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
278    ///     .simple_document("doc0", "Hello, world!")
279    ///     .simple_document("doc1", "Goodbye, world!")
280    ///     .build()
281    ///     .await
282    ///     .expect("Failed to embed documents");
283    /// ```
284    fn embeddings<D: Embed>(
285        &self,
286        model: &str,
287    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel<T>, D> {
288        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
289    }
290}
291
292impl<T> TranscriptionClient for Client<T>
293where
294    T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + 'static,
295    Client<T>: CompletionClient,
296{
297    type TranscriptionModel = TranscriptionModel<T>;
298
299    /// Create a transcription model with the given name.
300    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
301    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
302    fn transcription_model(&self, model: &str) -> TranscriptionModel<T> {
303        TranscriptionModel::new(self.clone(), model)
304    }
305}
306
307impl<T> VerifyClient for Client<T>
308where
309    T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static,
310    Client<T>: CompletionClient,
311{
312    #[cfg_attr(feature = "worker", worker::send)]
313    async fn verify(&self) -> Result<(), VerifyError> {
314        let req = self
315            .get("/v1beta/models")
316            .body(http_client::NoBody)
317            .map_err(|e| VerifyError::HttpError(e.into()))?;
318        let response = self.http_client.send::<_, Vec<u8>>(req).await?;
319
320        match response.status() {
321            reqwest::StatusCode::OK => Ok(()),
322            reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication),
323            reqwest::StatusCode::INTERNAL_SERVER_ERROR
324            | reqwest::StatusCode::SERVICE_UNAVAILABLE => {
325                let text = http_client::text(response).await?;
326                Err(VerifyError::ProviderError(text))
327            }
328            _ => {
329                // TODO: Find/write some alternative for this that uses `http::StatusCode` vs
330                // reqwest::StatusCode
331                //
332                // response.error_for_status()?;
333                Ok(())
334            }
335        }
336    }
337}
338
339impl_conversion_traits!(
340    AsImageGeneration,
341    AsAudioGeneration for Client<T>
342);
343
344#[derive(Debug, Deserialize)]
345pub struct ApiErrorResponse {
346    pub message: String,
347}
348
349#[derive(Debug, Deserialize)]
350#[serde(untagged)]
351pub enum ApiResponse<T> {
352    Ok(T),
353    Err(ApiErrorResponse),
354}