rig_core/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
3 Transport,
4};
5use crate::http_client::{self};
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
10#[cfg(any(feature = "image", feature = "audio"))]
11use crate::client::Nothing;
12
/// Default base URL of the Gemini (Generative Language) REST API.
///
/// Deliberately has no trailing slash: request paths are appended after a `/`.
const GEMINI_API_BASE_URL: &str =
    "https://generativelanguage.googleapis.com";
17
/// Provider extension for the Gemini `generateContent` API.
///
/// Holds the API key that is attached to every request this provider builds.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

// Hand-written instead of derived so that `{:?}` can never print the secret
// key; the placeholder matches the masking used by the `DebugExt` impl.
impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
23
/// Zero-sized marker selecting the `generateContent` flavour of the Gemini
/// client when used as a `ProviderBuilder`.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
27
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key that this provider attaches to outgoing requests.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

// Hand-written instead of derived so that `{:?}` can never print the secret
// key; the placeholder matches the masking used by the `DebugExt` impl.
impl Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
33
/// Zero-sized marker selecting the Interactions-API flavour of the Gemini
/// client when used as a `ProviderBuilder`.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
37
/// Gemini API key wrapper used as the client builder's key type.
///
/// Anything convertible into a `String` (e.g. `&str`, `String`) converts into
/// a key via `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S: Into<String>> From<S> for GeminiApiKey {
    fn from(value: S) -> Self {
        let key: String = value.into();
        GeminiApiKey(key)
    }
}
49
50pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
52pub type ClientBuilder<H = crate::markers::Missing> =
54 client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
55pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
57
// Marker impl: `GeminiApiKey` is the `ApiKey` type of both Gemini builders;
// presumably `ProviderBuilder::ApiKey` requires this bound — confirm in `crate::client`.
impl ApiKey for GeminiApiKey {}
59
60impl DebugExt for GeminiExt {
61 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
62 std::iter::once(("api_key", (&"******") as &dyn Debug))
63 }
64}
65
66impl DebugExt for GeminiInteractionsExt {
67 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
68 std::iter::once(("api_key", (&"******") as &dyn Debug))
69 }
70}
71
72impl Provider for GeminiExt {
73 type Builder = GeminiBuilder;
74
75 const VERIFY_PATH: &'static str = "/v1beta/models";
76
77 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
78 let trimmed = path.trim_start_matches('/');
79 let separator = if trimmed.contains('?') { "&" } else { "?" };
80
81 match transport {
82 Transport::Sse => format!(
83 "{base_url}/{trimmed}{separator}alt=sse&key={}",
84 self.api_key
85 ),
86 _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
87 }
88 }
89}
90
91impl Provider for GeminiInteractionsExt {
92 type Builder = GeminiInteractionsBuilder;
93
94 const VERIFY_PATH: &'static str = "/v1beta/models";
95
96 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
97 let trimmed = path.trim_start_matches('/');
98 match transport {
99 Transport::Sse => {
100 if trimmed.contains('?') {
101 format!("{}/{}&alt=sse", base_url, trimmed)
102 } else {
103 format!("{}/{}?alt=sse", base_url, trimmed)
104 }
105 }
106 _ => format!("{}/{}", base_url, trimmed),
107 }
108 }
109
110 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
111 Ok(req.header("x-goog-api-key", self.api_key.clone()))
112 }
113}
114
115impl<H> Capabilities<H> for GeminiExt {
116 type Completion = Capable<super::completion::CompletionModel<H>>;
117 type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
118 type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
119 type ModelListing = Capable<GeminiModelLister<H>>;
120
121 #[cfg(feature = "image")]
122 type ImageGeneration = Nothing;
123 #[cfg(feature = "audio")]
124 type AudioGeneration = Nothing;
125}
126
127impl<H> Capabilities<H> for GeminiInteractionsExt {
128 type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
129 type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
130 type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
131 type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
132
133 #[cfg(feature = "image")]
134 type ImageGeneration = Nothing;
135 #[cfg(feature = "audio")]
136 type AudioGeneration = Nothing;
137}
138
139impl ProviderBuilder for GeminiBuilder {
140 type Extension<H>
141 = GeminiExt
142 where
143 H: http_client::HttpClientExt;
144 type ApiKey = GeminiApiKey;
145
146 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
147
148 fn build<H>(
149 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
150 ) -> http_client::Result<Self::Extension<H>>
151 where
152 H: http_client::HttpClientExt,
153 {
154 Ok(GeminiExt {
155 api_key: builder.get_api_key().0.clone(),
156 })
157 }
158}
159
160impl ProviderBuilder for GeminiInteractionsBuilder {
161 type Extension<H>
162 = GeminiInteractionsExt
163 where
164 H: http_client::HttpClientExt;
165 type ApiKey = GeminiApiKey;
166
167 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
168
169 fn build<H>(
170 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
171 ) -> http_client::Result<Self::Extension<H>>
172 where
173 H: http_client::HttpClientExt,
174 {
175 Ok(GeminiInteractionsExt {
176 api_key: builder.get_api_key().0.clone(),
177 })
178 }
179}
180
181impl ProviderClient for Client {
182 type Input = GeminiApiKey;
183 type Error = crate::client::ProviderClientError;
184
185 fn from_env() -> Result<Self, Self::Error> {
187 let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
188 Self::new(api_key).map_err(Into::into)
189 }
190
191 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
192 Self::new(input).map_err(Into::into)
193 }
194}
195
196impl ProviderClient for InteractionsClient {
197 type Input = GeminiApiKey;
198 type Error = crate::client::ProviderClientError;
199
200 fn from_env() -> Result<Self, Self::Error> {
202 let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
203 Self::new(api_key).map_err(Into::into)
204 }
205
206 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
207 Self::new(input).map_err(Into::into)
208 }
209}
210
211impl<H> Client<H> {
212 pub fn interactions_api(self) -> InteractionsClient<H> {
214 let api_key = self.ext().api_key.clone();
215 self.with_ext(GeminiInteractionsExt { api_key })
216 }
217}
218
219impl<H> InteractionsClient<H> {
220 pub fn generate_content_api(self) -> Client<H> {
222 let api_key = self.ext().api_key.clone();
223 self.with_ext(GeminiExt { api_key })
224 }
225}
226
/// Error payload, used as the `Err` arm of [`ApiResponse`].
///
/// NOTE(review): Gemini REST errors are usually wrapped as
/// `{"error": {"code": ..., "message": ..., "status": ...}}`; this struct only
/// matches a top-level `message` field — confirm callers unwrap `error` first.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error message text taken from the response body.
    pub message: String,
}
232
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order:
/// a body that deserializes as `T` always becomes `Ok`, and only bodies that
/// fail as `T` are tried as `Err`. Variant order is therefore behavior.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
    /// Body parsed as the expected success type.
    Ok(T),
    /// Body did not parse as `T` but matched the error shape.
    Err(ApiErrorResponse),
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Both construction paths (direct `new` and the builder) must succeed
    /// with a placeholder key; no network traffic is involved.
    #[test]
    fn test_client_initialization() {
        let _via_new: Client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _via_builder: Client = builder.build().expect("Client::builder() failed");
    }
}
257}