rig/providers/gemini/client.rs
1use super::{
2 completion::CompletionModel, embedding::EmbeddingModel, transcription::TranscriptionModel,
3};
4use crate::client::{
5 impl_conversion_traits, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient,
6};
7use crate::{
8 agent::AgentBuilder,
9 embeddings::{self},
10 extractor::ExtractorBuilder,
11 Embed,
12};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15
// ----------------------------------------------------------------
// Google Gemini client
// ----------------------------------------------------------------

/// Default root of the Gemini (Generative Language) REST API; request paths are appended
/// to it by the client's `post` helpers.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
20
21#[derive(Debug, Clone)]
22pub struct Client {
23 base_url: String,
24 api_key: String,
25 http_client: reqwest::Client,
26}
27
28impl Client {
29 pub fn new(api_key: &str) -> Self {
30 Self::from_url(api_key, GEMINI_API_BASE_URL)
31 }
32 pub fn from_url(api_key: &str, base_url: &str) -> Self {
33 Self {
34 base_url: base_url.to_string(),
35 api_key: api_key.to_string(),
36 http_client: reqwest::Client::builder()
37 .default_headers({
38 let mut headers = reqwest::header::HeaderMap::new();
39 headers.insert(
40 reqwest::header::CONTENT_TYPE,
41 "application/json".parse().unwrap(),
42 );
43 headers
44 })
45 .build()
46 .expect("Gemini reqwest client should build"),
47 }
48 }
49
50 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
51 let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/");
52
53 tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****");
54 self.http_client.post(url)
55 }
56
57 pub fn post_sse(&self, path: &str) -> reqwest::RequestBuilder {
58 let url =
59 format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/");
60
61 tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****");
62 self.http_client.post(url)
63 }
64
65 /// Create an agent builder with the given completion model.
66 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
67 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
68 /// # Example
69 /// ```
70 /// use rig::providers::gemini::{Client, self};
71 ///
72 /// // Initialize the Google Gemini client
73 /// let gemini = Client::new("your-google-gemini-api-key");
74 ///
75 /// let agent = gemini.agent(gemini::completion::GEMINI_1_5_PRO)
76 /// .preamble("You are comedian AI with a mission to make people laugh.")
77 /// .temperature(0.0)
78 /// .build();
79 /// ```
80 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
81 AgentBuilder::new(self.completion_model(model))
82 }
83
84 /// Create an extractor builder with the given completion model.
85 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
86 &self,
87 model: &str,
88 ) -> ExtractorBuilder<T, CompletionModel> {
89 ExtractorBuilder::new(self.completion_model(model))
90 }
91}
92
93impl ProviderClient for Client {
94 /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable.
95 /// Panics if the environment variable is not set.
96 fn from_env() -> Self {
97 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
98 Self::new(&api_key)
99 }
100}
101
102impl CompletionClient for Client {
103 type CompletionModel = CompletionModel;
104
105 /// Create a completion model with the given name.
106 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
107 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
108 fn completion_model(&self, model: &str) -> CompletionModel {
109 CompletionModel::new(self.clone(), model)
110 }
111}
112
113impl EmbeddingsClient for Client {
114 type EmbeddingModel = EmbeddingModel;
115
116 /// Create an embedding model with the given name.
117 /// Note: default embedding dimension of 0 will be used if model is not known.
118 /// If this is the case, it's better to use function `embedding_model_with_ndims`
119 ///
120 /// # Example
121 /// ```
122 /// use rig::providers::gemini::{Client, self};
123 ///
124 /// // Initialize the Google Gemini client
125 /// let gemini = Client::new("your-google-gemini-api-key");
126 ///
127 /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001);
128 /// ```
129 fn embedding_model(&self, model: &str) -> EmbeddingModel {
130 EmbeddingModel::new(self.clone(), model, None)
131 }
132
133 /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model.
134 ///
135 /// # Example
136 /// ```
137 /// use rig::providers::gemini::{Client, self};
138 ///
139 /// // Initialize the Google Gemini client
140 /// let gemini = Client::new("your-google-gemini-api-key");
141 ///
142 /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024);
143 /// ```
144 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
145 EmbeddingModel::new(self.clone(), model, Some(ndims))
146 }
147
148 /// Create an embedding builder with the given embedding model.
149 ///
150 /// # Example
151 /// ```
152 /// use rig::providers::gemini::{Client, self};
153 ///
154 /// // Initialize the Google Gemini client
155 /// let gemini = Client::new("your-google-gemini-api-key");
156 ///
157 /// let embeddings = gemini.embeddings(gemini::embedding::EMBEDDING_GECKO_001)
158 /// .simple_document("doc0", "Hello, world!")
159 /// .simple_document("doc1", "Goodbye, world!")
160 /// .build()
161 /// .await
162 /// .expect("Failed to embed documents");
163 /// ```
164 fn embeddings<D: Embed>(
165 &self,
166 model: &str,
167 ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
168 embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
169 }
170}
171
172impl TranscriptionClient for Client {
173 type TranscriptionModel = TranscriptionModel;
174
175 /// Create a transcription model with the given name.
176 /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct.
177 /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig)
178 fn transcription_model(&self, model: &str) -> TranscriptionModel {
179 TranscriptionModel::new(self.clone(), model)
180 }
181}
182
183impl_conversion_traits!(
184 AsImageGeneration,
185 AsAudioGeneration for Client
186);
187
188#[derive(Debug, Deserialize)]
189pub struct ApiErrorResponse {
190 pub message: String,
191}
192
193#[derive(Debug, Deserialize)]
194#[serde(untagged)]
195pub enum ApiResponse<T> {
196 Ok(T),
197 Err(ApiErrorResponse),
198}