rig_core/providers/gemini/
client.rs1use crate::client::{
2 self, ApiKey, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient, Transport,
4};
5use crate::http_client::{self};
6use crate::providers::gemini::model_listing::{GeminiInteractionsModelLister, GeminiModelLister};
7use serde::Deserialize;
8use std::fmt::Debug;
9
/// Default base URL of the Gemini API, used as `BASE_URL` by both provider builders.
///
/// It carries no trailing slash because `build_uri` inserts the `/` before each request path.
const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com";
14
/// Provider extension for the Gemini `generateContent` family of endpoints.
///
/// Holds the API key used to authenticate requests. `Debug` is implemented by hand instead of
/// derived so that formatting the extension (e.g. in logs or panic messages) never prints the
/// secret key; the output masks it the same way the `DebugExt` implementation does.
#[derive(Default, Clone)]
pub struct GeminiExt {
    api_key: String,
}

impl Debug for GeminiExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never expose the real key; only a fixed placeholder is printed.
        f.debug_struct("GeminiExt")
            .field("api_key", &"******")
            .finish()
    }
}
20
/// Marker type that selects the Gemini `generateContent` client.
///
/// It is the `Builder` of `GeminiExt` and the provider parameter of the `ClientBuilder` alias;
/// it carries no state of its own.
#[derive(Clone, Debug, Default)]
pub struct GeminiBuilder;
24
/// Provider extension for the Gemini Interactions API.
///
/// Holds the API key that is sent with each request. `Debug` is implemented by hand instead of
/// derived so that formatting the extension never prints the secret key; the output masks it
/// the same way the `DebugExt` implementation does.
#[derive(Default, Clone)]
pub struct GeminiInteractionsExt {
    api_key: String,
}

impl Debug for GeminiInteractionsExt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never expose the real key; only a fixed placeholder is printed.
        f.debug_struct("GeminiInteractionsExt")
            .field("api_key", &"******")
            .finish()
    }
}
30
/// Marker type that selects the Gemini Interactions API client.
///
/// It is the `Builder` of `GeminiInteractionsExt`; it carries no state of its own.
#[derive(Clone, Debug, Default)]
pub struct GeminiInteractionsBuilder;
34
/// API key used to authenticate against the Gemini API.
///
/// Construct it from anything convertible into a `String` (`&str`, `String`, ...) via `From`/`Into`.
pub struct GeminiApiKey(String);

impl<S> From<S> for GeminiApiKey
where
    S: Into<String>,
{
    fn from(value: S) -> Self {
        let key: String = value.into();
        GeminiApiKey(key)
    }
}
46
47pub type Client<H = reqwest::Client> = client::Client<GeminiExt, H>;
49pub type ClientBuilder<H = crate::markers::Missing> =
51 client::ClientBuilder<GeminiBuilder, GeminiApiKey, H>;
52pub type InteractionsClient<H = reqwest::Client> = client::Client<GeminiInteractionsExt, H>;
54
55impl ApiKey for GeminiApiKey {}
56
57impl DebugExt for GeminiExt {
58 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
59 std::iter::once(("api_key", (&"******") as &dyn Debug))
60 }
61}
62
63impl DebugExt for GeminiInteractionsExt {
64 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
65 std::iter::once(("api_key", (&"******") as &dyn Debug))
66 }
67}
68
69impl Provider for GeminiExt {
70 type Builder = GeminiBuilder;
71
72 const VERIFY_PATH: &'static str = "/v1beta/models";
73
74 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
75 let trimmed = path.trim_start_matches('/');
76 let separator = if trimmed.contains('?') { "&" } else { "?" };
77
78 match transport {
79 Transport::Sse => format!(
80 "{base_url}/{trimmed}{separator}alt=sse&key={}",
81 self.api_key
82 ),
83 _ => format!("{base_url}/{trimmed}{separator}key={}", self.api_key),
84 }
85 }
86}
87
88impl Provider for GeminiInteractionsExt {
89 type Builder = GeminiInteractionsBuilder;
90
91 const VERIFY_PATH: &'static str = "/v1beta/models";
92
93 fn build_uri(&self, base_url: &str, path: &str, transport: Transport) -> String {
94 let trimmed = path.trim_start_matches('/');
95 match transport {
96 Transport::Sse => {
97 if trimmed.contains('?') {
98 format!("{}/{}&alt=sse", base_url, trimmed)
99 } else {
100 format!("{}/{}?alt=sse", base_url, trimmed)
101 }
102 }
103 _ => format!("{}/{}", base_url, trimmed),
104 }
105 }
106
107 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
108 Ok(req.header("x-goog-api-key", self.api_key.clone()))
109 }
110}
111
112impl<H> Capabilities<H> for GeminiExt {
113 type Completion = Capable<super::completion::CompletionModel<H>>;
114 type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
115 type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
116 type ModelListing = Capable<GeminiModelLister<H>>;
117
118 #[cfg(feature = "image")]
119 type ImageGeneration = Capable<super::image_generation::ImageGenerationModel<H>>;
120 #[cfg(feature = "audio")]
121 type AudioGeneration = Nothing;
122 type Rerank = Nothing;
123}
124
125impl<H> Capabilities<H> for GeminiInteractionsExt {
126 type Completion = Capable<super::interactions_api::InteractionsCompletionModel<H>>;
127 type Embeddings = Capable<super::embedding::EmbeddingModel<H>>;
128 type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
129 type ModelListing = Capable<GeminiInteractionsModelLister<H>>;
130
131 #[cfg(feature = "image")]
132 type ImageGeneration = Nothing;
133 #[cfg(feature = "audio")]
134 type AudioGeneration = Nothing;
135 type Rerank = Nothing;
136}
137
138impl ProviderBuilder for GeminiBuilder {
139 type Extension<H>
140 = GeminiExt
141 where
142 H: http_client::HttpClientExt;
143 type ApiKey = GeminiApiKey;
144
145 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
146
147 fn build<H>(
148 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
149 ) -> http_client::Result<Self::Extension<H>>
150 where
151 H: http_client::HttpClientExt,
152 {
153 Ok(GeminiExt {
154 api_key: builder.get_api_key().0.clone(),
155 })
156 }
157}
158
159impl ProviderBuilder for GeminiInteractionsBuilder {
160 type Extension<H>
161 = GeminiInteractionsExt
162 where
163 H: http_client::HttpClientExt;
164 type ApiKey = GeminiApiKey;
165
166 const BASE_URL: &'static str = GEMINI_API_BASE_URL;
167
168 fn build<H>(
169 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
170 ) -> http_client::Result<Self::Extension<H>>
171 where
172 H: http_client::HttpClientExt,
173 {
174 Ok(GeminiInteractionsExt {
175 api_key: builder.get_api_key().0.clone(),
176 })
177 }
178}
179
180impl ProviderClient for Client {
181 type Input = GeminiApiKey;
182 type Error = crate::client::ProviderClientError;
183
184 fn from_env() -> Result<Self, Self::Error> {
186 let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
187 Self::new(api_key).map_err(Into::into)
188 }
189
190 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
191 Self::new(input).map_err(Into::into)
192 }
193}
194
195impl ProviderClient for InteractionsClient {
196 type Input = GeminiApiKey;
197 type Error = crate::client::ProviderClientError;
198
199 fn from_env() -> Result<Self, Self::Error> {
201 let api_key = crate::client::required_env_var("GEMINI_API_KEY")?;
202 Self::new(api_key).map_err(Into::into)
203 }
204
205 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
206 Self::new(input).map_err(Into::into)
207 }
208}
209
210impl<H> Client<H> {
211 pub fn interactions_api(self) -> InteractionsClient<H> {
213 let api_key = self.ext().api_key.clone();
214 self.with_ext(GeminiInteractionsExt { api_key })
215 }
216}
217
218impl<H> InteractionsClient<H> {
219 pub fn generate_content_api(self) -> Client<H> {
221 let api_key = self.ext().api_key.clone();
222 self.with_ext(GeminiExt { api_key })
223 }
224}
225
226#[derive(Debug, Deserialize)]
228pub struct ApiErrorResponse {
229 pub message: String,
230}
231
232#[derive(Debug, Deserialize)]
234#[serde(untagged)]
235pub enum ApiResponse<T> {
236 Ok(T),
237 Err(ApiErrorResponse),
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Both the direct constructor and the builder produce a client.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }

    /// Converting between the two Gemini clients must carry the API key across.
    #[test]
    fn test_switching_apis_preserves_api_key() {
        let client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let interactions = client.interactions_api();
        assert_eq!(interactions.ext().api_key, "dummy-key");

        let back = interactions.generate_content_api();
        assert_eq!(back.ext().api_key, "dummy-key");
    }

    /// SSE requests get `alt=sse`, joined with `?` or `&` depending on the existing query.
    #[test]
    fn test_interactions_sse_uri_appends_alt_param() {
        let ext = GeminiInteractionsExt {
            api_key: "dummy-key".to_string(),
        };

        assert_eq!(
            ext.build_uri("https://example.com", "/v1beta/models", Transport::Sse),
            "https://example.com/v1beta/models?alt=sse"
        );
        assert_eq!(
            ext.build_uri(
                "https://example.com",
                "v1beta/models?pageSize=1",
                Transport::Sse
            ),
            "https://example.com/v1beta/models?pageSize=1&alt=sse"
        );
    }
}