rig/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Provider, ProviderBuilder, ProviderClient,
3 Transport,
4};
5use crate::http_client;
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
10#[cfg(any(feature = "image", feature = "audio"))]
11use crate::client::Nothing;
12
/// Default base URL of Google's Generative Language API.
///
/// Used as `BASE_URL` by both the `generateContent` builder and the Interactions builder.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
17
/// Provider extension for the Gemini `generateContent` client.
///
/// Holds the API key used to authenticate requests made through this client.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

// Hand-written instead of derived: a derived `Debug` would print the API key verbatim
// (e.g. `{:?}` on the extension), defeating the redaction applied in `DebugExt`.
impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
23
/// Zero-sized marker type selecting the `generateContent` flavour when building a Gemini client.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
27
/// Provider extension for the Gemini Interactions API client.
///
/// Holds the API key used to authenticate requests made through this client.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

// Hand-written instead of derived: a derived `Debug` would print the API key verbatim
// (e.g. `{:?}` on the extension), defeating the redaction applied in `DebugExt`.
impl Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
33
/// Zero-sized marker type selecting the Interactions API flavour when building a Gemini client.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
37
/// Newtype wrapping a Gemini API key string.
pub struct GeminiApiKey(String);

/// Lets any string-like value (`&str`, `String`, ...) be used where a [`GeminiApiKey`] is expected.
impl<S> From<S> for GeminiApiKey
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        let key: String = value.into();
        GeminiApiKey(key)
    }
}
49
50pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
52pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
54pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
56
57impl ApiKey for GeminiApiKey {}
58
59impl DebugExt for GeminiExt {
60 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
61 std::iter::once(("api_key", (&"******") as &dyn Debug))
62 }
63}
64
65impl DebugExt for GeminiInteractionsExt {
66 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
67 std::iter::once(("api_key", (&"******") as &dyn Debug))
68 }
69}
70
71impl Provider for GeminiExt {
72 type Builder = GeminiBuilder;
73
74 const VERIFY_PATH: &'static str = "/v1beta/models";
75
76 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
77 let trimmed = path.trim_start_matches('/');
78 let separator = if trimmed.contains('?') { "&" } else { "?" };
79
80 match transport {
81 Transport::Sse => format!(
82 "{base_url}/{trimmed}{separator}alt=sse&key={}",
83 self.api_key
84 ),
85 _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
86 }
87 }
88}
89
90impl Provider for GeminiInteractionsExt {
91 type Builder = GeminiInteractionsBuilder;
92
93 const VERIFY_PATH: &'static str = "/v1beta/models";
94
95 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
96 let trimmed = path.trim_start_matches('/');
97 match transport {
98 Transport::Sse => {
99 if trimmed.contains('?') {
100 format!("{}/{}&alt=sse", base_url, trimmed)
101 } else {
102 format!("{}/{}?alt=sse", base_url, trimmed)
103 }
104 }
105 _ => format!("{}/{}", base_url, trimmed),
106 }
107 }
108
109 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
110 Ok(req.header("x-goog-api-key", self.api_key.clone()))
111 }
112}
113
114impl<H> Capabilities<H> for GeminiExt {
115 type Completion = Capable<super::completion::CompletionModel>;
116 type Embeddings = Capable<super::embedding::EmbeddingModel>;
117 type Transcription = Capable<super::transcription::TranscriptionModel>;
118 type ModelListing = Capable<GeminiModelLister<H>>;
119
120 #[cfg(feature = "image")]
121 type ImageGeneration = Nothing;
122 #[cfg(feature = "audio")]
123 type AudioGeneration = Nothing;
124}
125
126impl<H> Capabilities<H> for GeminiInteractionsExt {
127 type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
128 type Embeddings = Capable<super::embedding::EmbeddingModel>;
129 type Transcription = Capable<super::transcription::TranscriptionModel>;
130 type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
131
132 #[cfg(feature = "image")]
133 type ImageGeneration = Nothing;
134 #[cfg(feature = "audio")]
135 type AudioGeneration = Nothing;
136}
137
138impl ProviderBuilder for GeminiBuilder {
139 type Extension<H>
140 = GeminiExt
141 where
142 H: http_client::HttpClientExt;
143 type ApiKey = GeminiApiKey;
144
145 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
146
147 fn build<H>(
148 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
149 ) -> http_client::Result<Self::Extension<H>>
150 where
151 H: http_client::HttpClientExt,
152 {
153 Ok(GeminiExt {
154 api_key: builder.get_api_key().0.clone(),
155 })
156 }
157}
158
159impl ProviderBuilder for GeminiInteractionsBuilder {
160 type Extension<H>
161 = GeminiInteractionsExt
162 where
163 H: http_client::HttpClientExt;
164 type ApiKey = GeminiApiKey;
165
166 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
167
168 fn build<H>(
169 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
170 ) -> http_client::Result<Self::Extension<H>>
171 where
172 H: http_client::HttpClientExt,
173 {
174 Ok(GeminiInteractionsExt {
175 api_key: builder.get_api_key().0.clone(),
176 })
177 }
178}
179
180impl ProviderClient for Client {
181 type Input = GeminiApiKey;
182
183 fn from_env() -> Self {
186 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
187 Self::new(api_key).unwrap()
188 }
189
190 fn from_val(input: Self::Input) -> Self {
191 Self::new(input).unwrap()
192 }
193}
194
195impl ProviderClient for InteractionsClient {
196 type Input = GeminiApiKey;
197
198 fn from_env() -> Self {
201 let api_key = std::env::var("GEMINI_API_KEY").expect("GEMINI_API_KEY not set");
202 Self::new(api_key).unwrap()
203 }
204
205 fn from_val(input: Self::Input) -> Self {
206 Self::new(input).unwrap()
207 }
208}
209
210impl<H> Client<H> {
211 pub fn interactions_api(self) -> InteractionsClient<H> {
213 let api_key = self.ext().api_key.clone();
214 self.with_ext(GeminiInteractionsExt { api_key })
215 }
216}
217
218impl<H> InteractionsClient<H> {
219 pub fn generate_content_api(self) -> Client<H> {
221 let api_key = self.ext().api_key.clone();
222 self.with_ext(GeminiExt { api_key })
223 }
224}
225
226#[derive(Debug, Deserialize)]
228pub struct ApiErrorResponse {
229 pub message: String,
230}
231
232#[derive(Debug, Deserialize)]
234#[serde(untagged)]
235pub enum ApiResponse<T> {
236 Ok(T),
237 Err(ApiErrorResponse),
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// The `generateContent` flavour authenticates through the `key` query parameter.
    #[test]
    fn test_generate_content_uri_appends_key() {
        let ext = GeminiExt {
            api_key: "dummy-key".to_string(),
        };
        assert_eq!(
            ext.build_uri("https://example.com", "/v1beta/models", Transport::Sse),
            "https://example.com/v1beta/models?alt=sse&key=dummy-key"
        );
        // An existing query string is extended with `&` rather than a second `?`.
        assert_eq!(
            ext.build_uri("https://example.com", "v1beta/models?pageSize=5", Transport::Sse),
            "https://example.com/v1beta/models?pageSize=5&alt=sse&key=dummy-key"
        );
    }

    /// The Interactions flavour never places the API key in the URL.
    #[test]
    fn test_interactions_uri_has_no_key() {
        let ext = GeminiInteractionsExt {
            api_key: "dummy-key".to_string(),
        };
        let uri = ext.build_uri("https://example.com", "/v1beta/interactions", Transport::Sse);
        assert_eq!(uri, "https://example.com/v1beta/interactions?alt=sse");
        assert!(!uri.contains("dummy-key"));
    }

    /// Switching between the two API flavours keeps the configured API key.
    #[test]
    fn test_api_switch_preserves_key() {
        let client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let interactions = client.interactions_api();
        assert_eq!(interactions.ext().api_key, "dummy-key");
        let back = interactions.generate_content_api();
        assert_eq!(back.ext().api_key, "dummy-key");
    }
}