rig/providers/gemini/
client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6    impl_conversion_traits,
7};
8use crate::{
9    Embed,
10    agent::AgentBuilder,
11    embeddings::{self},
12    extractor::ExtractorBuilder,
13};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16
// ================================================================
// Google Gemini Client
// ================================================================
/// Default base URL of the Google Gemini API.
///
/// Used by `ClientBuilder::new`; override it with `ClientBuilder::base_url`
/// (e.g. to target a proxy).
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
21
22pub struct ClientBuilder<'a> {
23    api_key: &'a str,
24    base_url: &'a str,
25    http_client: Option<reqwest::Client>,
26}
27
28impl<'a> ClientBuilder<'a> {
29    pub fn new(api_key: &'a str) -> Self {
30        Self {
31            api_key,
32            base_url: GEMINI_API_BASE_URL,
33            http_client: None,
34        }
35    }
36
37    pub fn base_url(mut self, base_url: &'a str) -> Self {
38        self.base_url = base_url;
39        self
40    }
41
42    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
43        self.http_client = Some(client);
44        self
45    }
46
47    pub fn build(self) -> Result<Client, ClientBuilderError> {
48        let mut default_headers = reqwest::header::HeaderMap::new();
49        default_headers.insert(
50            reqwest::header::CONTENT_TYPE,
51            "application/json".parse().unwrap(),
52        );
53        let http_client = if let Some(http_client) = self.http_client {
54            http_client
55        } else {
56            reqwest::Client::builder().build()?
57        };
58
59        Ok(Client {
60            base_url: self.base_url.to_string(),
61            api_key: self.api_key.to_string(),
62            default_headers,
63            http_client,
64        })
65    }
66}
67#[derive(Clone)]
68pub struct Client {
69    base_url: String,
70    api_key: String,
71    default_headers: reqwest::header::HeaderMap,
72    http_client: reqwest::Client,
73}
74
75impl std::fmt::Debug for Client {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("Client")
78            .field("base_url", &self.base_url)
79            .field("http_client", &self.http_client)
80            .field("default_headers", &self.default_headers)
81            .field("api_key", &"<REDACTED>")
82            .finish()
83    }
84}
85
86impl Client {
87    /// Create a new Google Gemini client builder.
88    ///
89    /// # Example
90    /// ```
91    /// use rig::providers::gemini::{ClientBuilder, self};
92    ///
93    /// // Initialize the Google Gemini client
94    /// let gemini_client = Client::builder("your-google-gemini-api-key")
95    ///    .build()
96    /// ```
97    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
98        ClientBuilder::new(api_key)
99    }
100
101    /// Create a new Google Gemini client. For more control, use the `builder` method.
102    ///
103    /// # Panics
104    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
105    pub fn new(api_key: &str) -> Self {
106        Self::builder(api_key)
107            .build()
108            .expect("Gemini client should build")
109    }
110
111    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
112        // API key gets inserted as query param - no need to add bearer auth or headers
113        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
114
115        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
116        self.http_client
117            .post(url)
118            .headers(self.default_headers.clone())
119    }
120
121    pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
122        let url =
123            format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
124
125        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
126        self.http_client
127            .post(url)
128            .headers(self.default_headers.clone())
129    }
130
131    /// Create an agent builder with the given completion model.
132    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
133    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
134    /// # Example
135    /// ```
136    /// use rig::providers::gemini::{Client, self};
137    ///
138    /// // Initialize the Google Gemini client
139    /// let gemini = Client::new("your-google-gemini-api-key");
140    ///
141    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
142    ///    .preamble("You are comedian AI with a mission to make people laugh.")
143    ///    .temperature(0.0)
144    ///    .build();
145    /// ```
146    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
147        AgentBuilder::new(self.completion_model(model))
148    }
149
150    /// Create an extractor builder with the given completion model.
151    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
152        &self,
153        model: &str,
154    ) -> ExtractorBuilder<T, CompletionModel> {
155        ExtractorBuilder::new(self.completion_model(model))
156    }
157}
158
159impl ProviderClient for Client {
160    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
161    /// Panics if the environment variable is not set.
162    fn from_env() -> Self {
163        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
164        Self::new(&api_key)
165    }
166
167    fn from_val(input: crate::client::ProviderValue) -> Self {
168        let crate::client::ProviderValue::Simple(api_key) = input else {
169            panic!("Incorrect provider value type")
170        };
171        Self::new(&api_key)
172    }
173}
174
175impl CompletionClient for Client {
176    type CompletionModel = CompletionModel;
177
178    /// Create a completion model with the given name.
179    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
180    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
181    fn completion_model(&self, model: &str) -> CompletionModel {
182        CompletionModel::new(self.clone(), model)
183    }
184}
185
186impl EmbeddingsClient for Client {
187    type EmbeddingModel = EmbeddingModel;
188
189    /// Create an embedding model with the given name.
190    /// Note: default embedding dimension of 0 will be used if model is not known.
191    /// If this is the case, it's better to use function `embedding_model_with_ndims`
192    ///
193    /// # Example
194    /// ```
195    /// use rig::providers::gemini::{Client, self};
196    ///
197    /// // Initialize the Google Gemini client
198    /// let gemini = Client::new("your-google-gemini-api-key");
199    ///
200    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
201    /// ```
202    fn embedding_model(&self, model: &str) -> EmbeddingModel {
203        EmbeddingModel::new(self.clone(), model, None)
204    }
205
206    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
207    ///
208    /// # Example
209    /// ```
210    /// use rig::providers::gemini::{Client, self};
211    ///
212    /// // Initialize the Google Gemini client
213    /// let gemini = Client::new("your-google-gemini-api-key");
214    ///
215    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
216    /// ```
217    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
218        EmbeddingModel::new(self.clone(), model, Some(ndims))
219    }
220
221    /// Create an embedding builder with the given embedding model.
222    ///
223    /// # Example
224    /// ```
225    /// use rig::providers::gemini::{Client, self};
226    ///
227    /// // Initialize the Google Gemini client
228    /// let gemini = Client::new("your-google-gemini-api-key");
229    ///
230    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
231    ///     .simple_document("doc0", "Hello, world!")
232    ///     .simple_document("doc1", "Goodbye, world!")
233    ///     .build()
234    ///     .await
235    ///     .expect("Failed to embed documents");
236    /// ```
237    fn embeddings<D: Embed>(
238        &self,
239        model: &str,
240    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
241        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
242    }
243}
244
245impl TranscriptionClient for Client {
246    type TranscriptionModel = TranscriptionModel;
247
248    /// Create a transcription model with the given name.
249    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
250    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
251    fn transcription_model(&self, model: &str) -> TranscriptionModel {
252        TranscriptionModel::new(self.clone(), model)
253    }
254}
255
// Invokes `crate::client::impl_conversion_traits!` for the `AsImageGeneration` and
// `AsAudioGeneration` traits on `Client`.
// NOTE(review): presumably this provides default conversion impls that report these
// capabilities as unsupported by the Gemini provider — confirm against the macro definition.
impl_conversion_traits!(
    AsImageGeneration,
    AsAudioGeneration for Client
);
260
261#[derive(Debug, Deserialize)]
262pub struct ApiErrorResponse {
263    pub message: String,
264}
265
/// Response body that is either a successful payload `T` or an [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries the variants in declaration order. `Ok(T)` is
/// attempted first, and `Err` is tried only if the body does not deserialize as `T`.
/// Keep `Ok` first: reordering would change which variant an ambiguous body maps to.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Successfully deserialized payload.
    Ok(T),
    /// Error payload returned by the API.
    Err(ApiErrorResponse),
}