rig/providers/gemini/
client.rs

1use super::{
2    completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5    CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, impl_conversion_traits,
6};
7use crate::{
8    Embed,
9    agent::AgentBuilder,
10    embeddings::{self},
11    extractor::ExtractorBuilder,
12};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
// ----------------------------------------------------------------
// Google Gemini client
// ----------------------------------------------------------------

/// Root URL of the public Google Generative Language (Gemini) API.
///
/// `Client::new` targets this host; use `Client::from_url` to point elsewhere.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
20
21#[derive(Clone)]
22pub struct Client {
23    base_url: String,
24    api_key: String,
25    default_headers: reqwest::header::HeaderMap,
26    http_client: reqwest::Client,
27}
28
29impl std::fmt::Debug for Client {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("Client")
32            .field("base_url", &self.base_url)
33            .field("http_client", &self.http_client)
34            .field("default_headers", &self.default_headers)
35            .field("api_key", &"<REDACTED>")
36            .finish()
37    }
38}
39
40impl Client {
41    pub fn new(api_key: &str) -> Self {
42        Self::from_url(api_key, GEMINI_API_BASE_URL)
43    }
44    pub fn from_url(api_key: &str, base_url: &str) -> Self {
45        let mut default_headers = reqwest::header::HeaderMap::new();
46        default_headers.insert(
47            reqwest::header::CONTENT_TYPE,
48            "application/json".parse().unwrap(),
49        );
50        Self {
51            base_url: base_url.to_string(),
52            api_key: api_key.to_string(),
53            default_headers,
54            http_client: reqwest::Client::builder()
55                .build()
56                .expect("Gemini reqwest client should build"),
57        }
58    }
59
60    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
61        // API key gets inserted as query param - no need to add bearer auth or headers
62        let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
63
64        tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
65        self.http_client
66            .post(url)
67            .headers(self.default_headers.clone())
68    }
69
70    /// Use your own `reqwest::Client`.
71    /// The default headers will be automatically attached upon trying to make a request.
72    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
73        self.http_client = client;
74
75        self
76    }
77
78    pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
79        let url =
80            format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
81
82        tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
83        self.http_client
84            .post(url)
85            .headers(self.default_headers.clone())
86    }
87
88    /// Create an agent builder with the given completion model.
89    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
90    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
91    /// # Example
92    /// ```
93    /// use rig::providers::gemini::{Client, self};
94    ///
95    /// // Initialize the Google Gemini client
96    /// let gemini = Client::new("your-google-gemini-api-key");
97    ///
98    /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
99    ///    .preamble("You are comedian AI with a mission to make people laugh.")
100    ///    .temperature(0.0)
101    ///    .build();
102    /// ```
103    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
104        AgentBuilder::new(self.completion_model(model))
105    }
106
107    /// Create an extractor builder with the given completion model.
108    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
109        &self,
110        model: &str,
111    ) -> ExtractorBuilder<T, CompletionModel> {
112        ExtractorBuilder::new(self.completion_model(model))
113    }
114}
115
116impl ProviderClient for Client {
117    /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
118    /// Panics if the environment variable is not set.
119    fn from_env() -> Self {
120        let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
121        Self::new(&api_key)
122    }
123
124    fn from_val(input: crate::client::ProviderValue) -> Self {
125        let crate::client::ProviderValue::Simple(api_key) = input else {
126            panic!("Incorrect provider value type")
127        };
128        Self::new(&api_key)
129    }
130}
131
132impl CompletionClient for Client {
133    type CompletionModel = CompletionModel;
134
135    /// Create a completion model with the given name.
136    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
137    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
138    fn completion_model(&self, model: &str) -> CompletionModel {
139        CompletionModel::new(self.clone(), model)
140    }
141}
142
143impl EmbeddingsClient for Client {
144    type EmbeddingModel = EmbeddingModel;
145
146    /// Create an embedding model with the given name.
147    /// Note: default embedding dimension of 0 will be used if model is not known.
148    /// If this is the case, it's better to use function `embedding_model_with_ndims`
149    ///
150    /// # Example
151    /// ```
152    /// use rig::providers::gemini::{Client, self};
153    ///
154    /// // Initialize the Google Gemini client
155    /// let gemini = Client::new("your-google-gemini-api-key");
156    ///
157    /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
158    /// ```
159    fn embedding_model(&self, model: &str) -> EmbeddingModel {
160        EmbeddingModel::new(self.clone(), model, None)
161    }
162
163    /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
164    ///
165    /// # Example
166    /// ```
167    /// use rig::providers::gemini::{Client, self};
168    ///
169    /// // Initialize the Google Gemini client
170    /// let gemini = Client::new("your-google-gemini-api-key");
171    ///
172    /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
173    /// ```
174    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
175        EmbeddingModel::new(self.clone(), model, Some(ndims))
176    }
177
178    /// Create an embedding builder with the given embedding model.
179    ///
180    /// # Example
181    /// ```
182    /// use rig::providers::gemini::{Client, self};
183    ///
184    /// // Initialize the Google Gemini client
185    /// let gemini = Client::new("your-google-gemini-api-key");
186    ///
187    /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
188    ///     .simple_document("doc0", "Hello, world!")
189    ///     .simple_document("doc1", "Goodbye, world!")
190    ///     .build()
191    ///     .await
192    ///     .expect("Failed to embed documents");
193    /// ```
194    fn embeddings<D: Embed>(
195        &self,
196        model: &str,
197    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
198        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
199    }
200}
201
202impl TranscriptionClient for Client {
203    type TranscriptionModel = TranscriptionModel;
204
205    /// Create a transcription model with the given name.
206    /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
207    /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
208    fn transcription_model(&self, model: &str) -> TranscriptionModel {
209        TranscriptionModel::new(self.clone(), model)
210    }
211}
212
213impl_conversion_traits!(
214    AsImageGeneration,
215    AsAudioGeneration for Client
216);
217
218#[derive(Debug, Deserialize)]
219pub struct ApiErrorResponse {
220    pub message: String,
221}
222
223#[derive(Debug, Deserialize)]
224#[serde(untagged)]
225pub enum ApiResponse<T> {
226    Ok(T),
227    Err(ApiErrorResponse),
228}