rig/providers/gemini/client.rs
1use super::{
2 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5 CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, impl_conversion_traits,
6};
7use crate::{
8 Embed,
9 agent::AgentBuilder,
10 embeddings::{self},
11 extractor::ExtractorBuilder,
12};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
16// ================================================================
17// Google Gemini Client
18// ================================================================
19const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
20
21#[derive(Clone)]
22pub struct Client {
23 base_url: String,
24 api_key: String,
25 default_headers: reqwest::header::HeaderMap,
26 http_client: reqwest::Client,
27}
28
29impl std::fmt::Debug for Client {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("Client")
32 .field("base_url", &self.base_url)
33 .field("http_client", &self.http_client)
34 .field("default_headers", &self.default_headers)
35 .field("api_key", &"<REDACTED>")
36 .finish()
37 }
38}
39
40impl Client {
41 pub fn new(api_key: &str) -> Self {
42 Self::from_url(api_key, GEMINI_API_BASE_URL)
43 }
44 pub fn from_url(api_key: &str, base_url: &str) -> Self {
45 let mut default_headers = reqwest::header::HeaderMap::new();
46 default_headers.insert(
47 reqwest::header::CONTENT_TYPE,
48 "application/json".parse().unwrap(),
49 );
50 Self {
51 base_url: base_url.to_string(),
52 api_key: api_key.to_string(),
53 default_headers,
54 http_client: reqwest::Client::builder()
55 .build()
56 .expect("Gemini reqwest client should build"),
57 }
58 }
59
60 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
61 // API key gets inserted as query param - no need to add bearer auth or headers
62 let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
63
64 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
65 self.http_client
66 .post(url)
67 .headers(self.default_headers.clone())
68 }
69
70 /// Use your own `reqwest::Client`.
71 /// The default headers will be automatically attached upon trying to make a request.
72 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
73 self.http_client = client;
74
75 self
76 }
77
78 pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
79 let url =
80 format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
81
82 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
83 self.http_client
84 .post(url)
85 .headers(self.default_headers.clone())
86 }
87
88 /// Create an agent builder with the given completion model.
89 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
90 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
91 /// # Example
92 /// ```
93 /// use rig::providers::gemini::{Client, self};
94 ///
95 /// // Initialize the Google Gemini client
96 /// let gemini = Client::new("your-google-gemini-api-key");
97 ///
98 /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
99 /// .preamble("You are comedian AI with a mission to make people laugh.")
100 /// .temperature(0.0)
101 /// .build();
102 /// ```
103 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
104 AgentBuilder::new(self.completion_model(model))
105 }
106
107 /// Create an extractor builder with the given completion model.
108 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
109 &self,
110 model: &str,
111 ) -> ExtractorBuilder<T, CompletionModel> {
112 ExtractorBuilder::new(self.completion_model(model))
113 }
114}
115
116impl ProviderClient for Client {
117 /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
118 /// Panics if the environment variable is not set.
119 fn from_env() -> Self {
120 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
121 Self::new(&api_key)
122 }
123}
124
125impl CompletionClient for Client {
126 type CompletionModel = CompletionModel;
127
128 /// Create a completion model with the given name.
129 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
130 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
131 fn completion_model(&self, model: &str) -> CompletionModel {
132 CompletionModel::new(self.clone(), model)
133 }
134}
135
136impl EmbeddingsClient for Client {
137 type EmbeddingModel = EmbeddingModel;
138
139 /// Create an embedding model with the given name.
140 /// Note: default embedding dimension of 0 will be used if model is not known.
141 /// If this is the case, it's better to use function `embedding_model_with_ndims`
142 ///
143 /// # Example
144 /// ```
145 /// use rig::providers::gemini::{Client, self};
146 ///
147 /// // Initialize the Google Gemini client
148 /// let gemini = Client::new("your-google-gemini-api-key");
149 ///
150 /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
151 /// ```
152 fn embedding_model(&self, model: &str) -> EmbeddingModel {
153 EmbeddingModel::new(self.clone(), model, None)
154 }
155
156 /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
157 ///
158 /// # Example
159 /// ```
160 /// use rig::providers::gemini::{Client, self};
161 ///
162 /// // Initialize the Google Gemini client
163 /// let gemini = Client::new("your-google-gemini-api-key");
164 ///
165 /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
166 /// ```
167 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
168 EmbeddingModel::new(self.clone(), model, Some(ndims))
169 }
170
171 /// Create an embedding builder with the given embedding model.
172 ///
173 /// # Example
174 /// ```
175 /// use rig::providers::gemini::{Client, self};
176 ///
177 /// // Initialize the Google Gemini client
178 /// let gemini = Client::new("your-google-gemini-api-key");
179 ///
180 /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
181 /// .simple_document("doc0", "Hello, world!")
182 /// .simple_document("doc1", "Goodbye, world!")
183 /// .build()
184 /// .await
185 /// .expect("Failed to embed documents");
186 /// ```
187 fn embeddings<D: Embed>(
188 &self,
189 model: &str,
190 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
191 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
192 }
193}
194
195impl TranscriptionClient for Client {
196 type TranscriptionModel = TranscriptionModel;
197
198 /// Create a transcription model with the given name.
199 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
200 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
201 fn transcription_model(&self, model: &str) -> TranscriptionModel {
202 TranscriptionModel::new(self.clone(), model)
203 }
204}
205
// Implements the `AsImageGeneration` and `AsAudioGeneration` conversion traits for
// `Client` using the crate-level `impl_conversion_traits!` macro.
// NOTE(review): this file defines no image- or audio-generation model for Gemini, so
// the macro presumably supplies fallback ("not supported") implementations. Confirm
// against the macro definition in `crate::client`.
impl_conversion_traits!(
    AsImageGeneration,
    AsAudioGeneration for Client
);
210
211#[derive(Debug, Deserialize)]
212pub struct ApiErrorResponse {
213 pub message: String,
214}
215
/// A Gemini API response body: either the expected payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in declaration order.
/// `Ok(T)` is attempted first, and `Err` is used only if `T` fails to deserialize.
/// Do not reorder the variants.
/// NOTE(review): if `T` can deserialize from an error body (e.g. every field is optional or
/// defaulted), error bodies will be read as `Ok`. Confirm against the payload types used.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Successfully parsed response payload.
    Ok(T),
    /// Error payload returned by the API.
    Err(ApiErrorResponse),
}