1use crate::{
2 backend,
3 batch::{BatchBuilder, BatchHandle},
4 cache::{CacheBuilder, CachedContentHandle},
5 embedding::{
6 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
7 EmbedBuilder, EmbedContentRequest,
8 },
9 files::{
10 handle::FileHandle,
11 model::{File, ListFilesResponse},
12 },
13 generation::{ContentBuilder, GenerateContentRequest, GenerationResponse},
14};
15use eventsource_stream::EventStreamError;
16use futures::Stream;
17#[cfg(feature = "vertex")]
18use google_cloud_aiplatform_v1::client::PredictionService;
19#[cfg(feature = "vertex")]
20use google_cloud_auth::credentials::{self, Credentials};
21use mime::Mime;
22use reqwest::{ClientBuilder, header::InvalidHeaderValue};
23use serde::{Deserialize, Serialize};
24use snafu::{ResultExt, Snafu};
25use std::{
26 fmt::{self, Formatter},
27 sync::{Arc, LazyLock},
28};
29use tracing::{Level, Span, instrument};
30use url::Url;
31
32use crate::batch::model::*;
33use crate::cache::model::*;
34
/// Base URL of the `v1beta` Generative Language API, parsed lazily on first use.
///
/// The literal is a constant, so a parse failure here would indicate a
/// programming error rather than a runtime condition; hence the `expect`.
static DEFAULT_BASE_URL: LazyLock<Url> = LazyLock::new(|| {
    Url::parse("https://generativelanguage.googleapis.com/v1beta/")
        .expect("unreachable error: failed to parse default base URL")
});
/// Base URL of the `v1` Generative Language API, parsed lazily on first use.
///
/// The literal is a constant, so a parse failure here would indicate a
/// programming error rather than a runtime condition; hence the `expect`.
static V1_BASE_URL: LazyLock<Url> = LazyLock::new(|| {
    Url::parse("https://generativelanguage.googleapis.com/v1/")
        .expect("unreachable error: failed to parse v1 base URL")
});
43
/// Identifier of a model that requests can be sent to.
///
/// Each named variant (de)serializes to the `models/...` string given in its
/// `serde(rename)` attribute. Anything else is carried verbatim by
/// [`Model::Custom`]. The default is [`Model::Gemini25Flash`].
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum Model {
    // --- Gemini 3.1 family ---
    /// `models/gemini-3.1-pro-preview`
    #[serde(rename = "models/gemini-3.1-pro-preview")]
    Gemini31ProPreview,
    /// `models/gemini-3.1-flash-lite-preview`
    #[serde(rename = "models/gemini-3.1-flash-lite-preview")]
    Gemini31FlashLitePreview,

    // --- Gemini 3 family ---
    /// `models/gemini-3-pro-preview`
    #[serde(rename = "models/gemini-3-pro-preview")]
    Gemini3ProPreview,
    /// `models/gemini-3-pro-image-preview`
    #[serde(rename = "models/gemini-3-pro-image-preview")]
    Gemini3ProImagePreview,
    /// `models/gemini-3-flash-preview`
    #[serde(rename = "models/gemini-3-flash-preview")]
    Gemini3FlashPreview,

    // --- Gemini 2.5 family ---
    /// `models/gemini-2.5-pro`
    #[serde(rename = "models/gemini-2.5-pro")]
    Gemini25Pro,
    /// `models/gemini-2.5-pro-preview-tts`
    #[serde(rename = "models/gemini-2.5-pro-preview-tts")]
    Gemini25ProPreviewTts,
    /// `models/gemini-2.5-flash` — the value returned by `Model::default()`.
    #[default]
    #[serde(rename = "models/gemini-2.5-flash")]
    Gemini25Flash,
    /// `models/gemini-2.5-flash-preview-09-2025`
    #[serde(rename = "models/gemini-2.5-flash-preview-09-2025")]
    Gemini25FlashPreview092025,
    /// `models/gemini-2.5-flash-image`
    #[serde(rename = "models/gemini-2.5-flash-image")]
    Gemini25FlashImage,
    /// `models/gemini-2.5-flash-image-preview`
    #[deprecated(note = "Use Model::Gemini25FlashImage instead")]
    #[serde(rename = "models/gemini-2.5-flash-image-preview")]
    Gemini25FlashImagePreview,
    /// `models/gemini-2.5-flash-native-audio-preview-12-2025`
    #[serde(rename = "models/gemini-2.5-flash-native-audio-preview-12-2025")]
    Gemini25FlashLive122025,
    /// `models/gemini-2.5-flash-native-audio-preview-09-2025`
    #[serde(rename = "models/gemini-2.5-flash-native-audio-preview-09-2025")]
    Gemini25FlashLive092025,
    /// `models/gemini-2.5-flash-preview-tts`
    #[serde(rename = "models/gemini-2.5-flash-preview-tts")]
    Gemini25FlashPreviewTts,
    /// `models/gemini-2.5-flash-lite`
    #[serde(rename = "models/gemini-2.5-flash-lite")]
    Gemini25FlashLite,
    /// `models/gemini-2.5-flash-lite-preview-09-2025`
    #[serde(rename = "models/gemini-2.5-flash-lite-preview-09-2025")]
    Gemini25FlashLitePreview092025,

    // --- Gemini 2.0 family (all deprecated) ---
    /// `models/gemini-2.0-flash`
    #[deprecated(note = "Gemini 2.0 models shut down March 31, 2026")]
    #[serde(rename = "models/gemini-2.0-flash")]
    Gemini20Flash,
    /// `models/gemini-2.0-flash-001`
    #[deprecated(note = "Gemini 2.0 models shut down March 31, 2026")]
    #[serde(rename = "models/gemini-2.0-flash-001")]
    Gemini20Flash001,
    /// `models/gemini-2.0-flash-exp`
    #[deprecated(note = "Gemini 2.0 models shut down March 31, 2026")]
    #[serde(rename = "models/gemini-2.0-flash-exp")]
    Gemini20FlashExp,
    /// `models/gemini-2.0-flash-lite`
    #[deprecated(note = "Gemini 2.0 models shut down March 31, 2026")]
    #[serde(rename = "models/gemini-2.0-flash-lite")]
    Gemini20FlashLite,
    /// `models/gemini-2.0-flash-lite-001`
    #[deprecated(note = "Gemini 2.0 models shut down March 31, 2026")]
    #[serde(rename = "models/gemini-2.0-flash-lite-001")]
    Gemini20FlashLite001,

    // --- Embedding models ---
    /// `models/gemini-embedding-001`
    #[serde(rename = "models/gemini-embedding-001")]
    GeminiEmbedding001,
    /// `models/text-embedding-004`
    #[deprecated(note = "Use Model::GeminiEmbedding001 (gemini-embedding-001) instead")]
    #[serde(rename = "models/text-embedding-004")]
    TextEmbedding004,

    /// Any other model identifier, stored exactly as supplied.
    ///
    /// Marked `serde(untagged)`, so it is (de)serialized as the bare string
    /// rather than under a variant name.
    #[serde(untagged)]
    Custom(String),
}
121
122impl Model {
123 pub fn as_str(&self) -> &str {
124 #[allow(deprecated)]
125 match self {
126 Model::Gemini31ProPreview => "models/gemini-3.1-pro-preview",
127 Model::Gemini31FlashLitePreview => "models/gemini-3.1-flash-lite-preview",
128 Model::Gemini3ProPreview => "models/gemini-3-pro-preview",
129 Model::Gemini3ProImagePreview => "models/gemini-3-pro-image-preview",
130 Model::Gemini3FlashPreview => "models/gemini-3-flash-preview",
131 Model::Gemini25Pro => "models/gemini-2.5-pro",
132 Model::Gemini25ProPreviewTts => "models/gemini-2.5-pro-preview-tts",
133 Model::Gemini25Flash => "models/gemini-2.5-flash",
134 Model::Gemini25FlashPreview092025 => "models/gemini-2.5-flash-preview-09-2025",
135 Model::Gemini25FlashImage => "models/gemini-2.5-flash-image",
136 Model::Gemini25FlashImagePreview => "models/gemini-2.5-flash-image-preview",
137 Model::Gemini25FlashLive122025 => {
138 "models/gemini-2.5-flash-native-audio-preview-12-2025"
139 }
140 Model::Gemini25FlashLive092025 => {
141 "models/gemini-2.5-flash-native-audio-preview-09-2025"
142 }
143 Model::Gemini25FlashPreviewTts => "models/gemini-2.5-flash-preview-tts",
144 Model::Gemini25FlashLite => "models/gemini-2.5-flash-lite",
145 Model::Gemini25FlashLitePreview092025 => "models/gemini-2.5-flash-lite-preview-09-2025",
146 Model::Gemini20Flash => "models/gemini-2.0-flash",
147 Model::Gemini20Flash001 => "models/gemini-2.0-flash-001",
148 Model::Gemini20FlashExp => "models/gemini-2.0-flash-exp",
149 Model::Gemini20FlashLite => "models/gemini-2.0-flash-lite",
150 Model::Gemini20FlashLite001 => "models/gemini-2.0-flash-lite-001",
151 Model::GeminiEmbedding001 => "models/gemini-embedding-001",
152 Model::TextEmbedding004 => "models/text-embedding-004",
153 Model::Custom(model) => model,
154 }
155 }
156
157 pub fn vertex_model_path(&self, project_id: &str, location: &str) -> String {
158 #[allow(deprecated)]
159 let model_id = match self {
160 Model::Gemini31ProPreview => "gemini-3.1-pro-preview",
161 Model::Gemini31FlashLitePreview => "gemini-3.1-flash-lite-preview",
162 Model::Gemini3ProPreview => "gemini-3-pro-preview",
163 Model::Gemini3ProImagePreview => "gemini-3-pro-image-preview",
164 Model::Gemini3FlashPreview => "gemini-3-flash-preview",
165 Model::Gemini25Pro => "gemini-2.5-pro",
166 Model::Gemini25ProPreviewTts => "gemini-2.5-pro-preview-tts",
167 Model::Gemini25Flash => "gemini-2.5-flash",
168 Model::Gemini25FlashPreview092025 => "gemini-2.5-flash-preview-09-2025",
169 Model::Gemini25FlashImage => "gemini-2.5-flash-image",
170 Model::Gemini25FlashImagePreview => "gemini-2.5-flash-image-preview",
171 Model::Gemini25FlashLive122025 => "gemini-2.5-flash-native-audio-preview-12-2025",
172 Model::Gemini25FlashLive092025 => "gemini-2.5-flash-native-audio-preview-09-2025",
173 Model::Gemini25FlashPreviewTts => "gemini-2.5-flash-preview-tts",
174 Model::Gemini25FlashLite => "gemini-2.5-flash-lite",
175 Model::Gemini25FlashLitePreview092025 => "gemini-2.5-flash-lite-preview-09-2025",
176 Model::Gemini20Flash => "gemini-2.0-flash",
177 Model::Gemini20Flash001 => "gemini-2.0-flash-001",
178 Model::Gemini20FlashExp => "gemini-2.0-flash-exp",
179 Model::Gemini20FlashLite => "gemini-2.0-flash-lite",
180 Model::Gemini20FlashLite001 => "gemini-2.0-flash-lite-001",
181 Model::GeminiEmbedding001 => "gemini-embedding-001",
182 Model::TextEmbedding004 => "text-embedding-004",
183 Model::Custom(model) => {
184 if model.starts_with("projects/") {
185 return model.clone();
186 }
187 if model.starts_with("publishers/") {
188 return format!("projects/{project_id}/locations/{location}/{model}");
189 }
190 model.strip_prefix("models/").unwrap_or(model)
191 }
192 };
193 format!("projects/{project_id}/locations/{location}/publishers/google/models/{model_id}")
194 }
195}
196
197impl From<String> for Model {
198 #[allow(deprecated)]
199 fn from(model: String) -> Self {
200 let bare = model.strip_prefix("models/").unwrap_or(&model);
202 match bare {
203 "gemini-3.1-pro-preview" => Self::Gemini31ProPreview,
205 "gemini-3-pro-preview" => Self::Gemini3ProPreview,
207 "gemini-3-pro-image-preview" => Self::Gemini3ProImagePreview,
208 "gemini-3-flash-preview" => Self::Gemini3FlashPreview,
209 "gemini-2.5-pro" => Self::Gemini25Pro,
211 "gemini-2.5-pro-preview-tts" => Self::Gemini25ProPreviewTts,
212 "gemini-2.5-flash" => Self::Gemini25Flash,
213 "gemini-2.5-flash-preview-09-2025" => Self::Gemini25FlashPreview092025,
214 "gemini-2.5-flash-image" => Self::Gemini25FlashImage,
215 "gemini-2.5-flash-image-preview" => Self::Gemini25FlashImagePreview,
216 "gemini-2.5-flash-native-audio-preview-12-2025" => Self::Gemini25FlashLive122025,
217 "gemini-2.5-flash-native-audio-preview-09-2025" => Self::Gemini25FlashLive092025,
218 "gemini-2.5-flash-preview-tts" => Self::Gemini25FlashPreviewTts,
219 "gemini-2.5-flash-lite" => Self::Gemini25FlashLite,
220 "gemini-2.5-flash-lite-preview-09-2025" => Self::Gemini25FlashLitePreview092025,
221 "gemini-2.0-flash" => Self::Gemini20Flash,
223 "gemini-2.0-flash-001" => Self::Gemini20Flash001,
224 "gemini-2.0-flash-exp" => Self::Gemini20FlashExp,
225 "gemini-2.0-flash-lite" => Self::Gemini20FlashLite,
226 "gemini-2.0-flash-lite-001" => Self::Gemini20FlashLite001,
227 "gemini-embedding-001" => Self::GeminiEmbedding001,
229 "text-embedding-004" => Self::TextEmbedding004,
230 _ => Self::Custom(model),
231 }
232 }
233}
234
235impl fmt::Display for Model {
236 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
237 #[allow(deprecated)]
238 match self {
239 Model::Custom(model) => {
240 if model.starts_with("models/")
242 || model.starts_with("projects/")
243 || model.starts_with("publishers/")
244 {
245 write!(f, "{model}")
246 } else {
247 write!(f, "models/{model}")
248 }
249 }
250 other => write!(f, "{}", other.as_str()),
251 }
252 }
253}
254
255#[derive(Debug, Snafu)]
260#[snafu(visibility(pub))]
261pub enum Error {
262 #[snafu(display("failed to parse API key"))]
263 InvalidApiKey {
264 source: InvalidHeaderValue,
265 },
266
267 #[snafu(display("failed to construct URL (probably incorrect model name): {suffix}"))]
268 ConstructUrl {
269 source: url::ParseError,
270 suffix: String,
271 },
272
273 #[snafu(display("failed to perform request: {source}"))]
274 PerformRequestNew {
275 source: reqwest::Error,
276 },
277
278 #[snafu(display("failed to perform request to '{url}'"))]
279 PerformRequest {
280 source: reqwest::Error,
281 url: Url,
282 },
283
284 #[snafu(display("bad response from server; code {code}; description: {}", description.as_deref().unwrap_or("none")))]
285 BadResponse {
286 code: u16,
287 description: Option<String>,
288 },
289
290 MissingResponseHeader {
291 header: String,
292 },
293
294 #[snafu(display("failed to obtain stream SSE part"))]
295 BadPart {
296 source: EventStreamError<reqwest::Error>,
297 },
298
299 #[snafu(display("failed to deserialize JSON response"))]
300 Deserialize {
301 source: serde_json::Error,
302 },
303
304 #[snafu(display("failed to generate content"))]
305 DecodeResponse {
306 source: reqwest::Error,
307 },
308
309 #[snafu(display("failed to parse URL"))]
310 UrlParse {
311 source: url::ParseError,
312 },
313
314 #[snafu(display("failed to build google cloud credentials"))]
315 #[cfg(feature = "vertex")]
316 GoogleCloudAuth {
317 source: google_cloud_auth::build_errors::Error,
318 },
319
320 #[snafu(display("failed to obtain google cloud auth headers"))]
321 #[cfg(feature = "vertex")]
322 GoogleCloudCredentialHeaders {
323 source: google_cloud_auth::errors::CredentialsError,
324 },
325
326 #[snafu(display("google cloud credentials returned NotModified without cached headers"))]
327 GoogleCloudCredentialHeadersUnavailable,
328
329 #[snafu(display("failed to parse google cloud credentials JSON"))]
330 GoogleCloudCredentialParse {
331 source: serde_json::Error,
332 },
333
334 #[snafu(display("failed to build google cloud vertex client"))]
335 #[cfg(feature = "vertex")]
336 GoogleCloudClientBuild {
337 source: google_cloud_gax::client_builder::Error,
338 },
339
340 #[snafu(display("failed to send google cloud vertex request"))]
341 #[cfg(feature = "vertex")]
342 GoogleCloudRequest {
343 source: google_cloud_aiplatform_v1::Error,
344 },
345
346 #[snafu(display("failed to serialize google cloud request"))]
347 GoogleCloudRequestSerialize {
348 source: serde_json::Error,
349 },
350
351 #[snafu(display("failed to deserialize google cloud request"))]
352 GoogleCloudRequestDeserialize {
353 source: serde_json::Error,
354 },
355
356 #[snafu(display("failed to serialize google cloud response"))]
357 GoogleCloudResponseSerialize {
358 source: serde_json::Error,
359 },
360
361 #[snafu(display("failed to deserialize google cloud response"))]
362 GoogleCloudResponseDeserialize {
363 source: serde_json::Error,
364 },
365
366 #[snafu(display("google cloud request payload is not an object"))]
367 GoogleCloudRequestNotObject,
368
369 #[snafu(display("google cloud configuration is required for this authentication mode"))]
370 MissingGoogleCloudConfig,
371
372 #[snafu(display("google cloud authentication is required for this configuration"))]
373 MissingGoogleCloudAuth,
374
375 #[snafu(display("service account JSON is missing required field 'project_id'"))]
376 MissingGoogleCloudProjectId,
377
378 #[snafu(display("api key is required for this configuration"))]
379 MissingApiKey,
380
381 #[snafu(display(
382 "operation '{operation}' is not supported with the google cloud sdk backend (PredictionService currently exposes generateContent/embedContent only)"
383 ))]
384 GoogleCloudUnsupported {
385 operation: &'static str,
386 },
387
388 #[snafu(display("failed to create tokio runtime for google cloud client"))]
389 TokioRuntime {
390 source: std::io::Error,
391 },
392
393 #[snafu(display("google cloud client initialization thread panicked"))]
394 GoogleCloudInitThreadPanicked,
395
396 #[snafu(display("I/O error during file operations"))]
397 Io {
398 source: std::io::Error,
399 },
400
401 #[snafu(display("invalid generation config: {message}"))]
402 InvalidGenerationConfig {
403 message: String,
404 },
405}
406
/// Low-level client: a model plus the backend that actually performs requests.
///
/// Every request method on this type forwards to `backend`.
pub struct GeminiClient {
    /// Model this client was created for.
    pub model: Model,
    /// Type-erased backend that all request methods delegate to.
    backend: Box<dyn backend::GeminiBackend>,
}
419
420impl std::fmt::Debug for GeminiClient {
421 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
422 f.debug_struct("GeminiClient")
423 .field("model", &self.model)
424 .field("backend", &self.backend)
425 .finish()
426 }
427}
428
429impl GeminiClient {
430 fn with_studio(model: Model, studio: backend::studio::StudioBackend) -> Self {
432 Self { model, backend: Box::new(studio) }
433 }
434
435 #[cfg(feature = "vertex")]
437 fn with_vertex(model: Model, vertex: backend::vertex::VertexBackend) -> Self {
438 Self { model, backend: Box::new(vertex) }
439 }
440
441 #[instrument(skip_all, fields(
444 model,
445 messages.parts.count = request.contents.len(),
446 tools.present = request.tools.is_some(),
447 system.instruction.present = request.system_instruction.is_some(),
448 cached.content.present = request.cached_content.is_some(),
449 usage.prompt_tokens,
450 usage.candidates_tokens,
451 usage.thoughts_tokens,
452 usage.cached_content_tokens,
453 usage.total_tokens,
454 ), ret(level = Level::TRACE), err)]
455 pub(crate) async fn generate_content_raw(
456 &self,
457 request: GenerateContentRequest,
458 ) -> Result<GenerationResponse, Error> {
459 if let Some(ref gc) = request.generation_config {
460 gc.validate().map_err(|message| Error::InvalidGenerationConfig { message })?;
461 }
462
463 let response = self.backend.generate_content(request).await?;
464
465 if let Some(usage) = &response.usage_metadata {
466 #[rustfmt::skip]
467 Span::current()
468 .record("usage.prompt_tokens", usage.prompt_token_count)
469 .record("usage.candidates_tokens", usage.candidates_token_count)
470 .record("usage.thoughts_tokens", usage.thoughts_token_count)
471 .record("usage.cached_content_tokens", usage.cached_content_token_count)
472 .record("usage.total_tokens", usage.total_token_count);
473 tracing::debug!("generation usage evaluated");
474 }
475
476 Ok(response)
477 }
478
479 #[instrument(skip_all, fields(
480 model,
481 messages.parts.count = request.contents.len(),
482 tools.present = request.tools.is_some(),
483 system.instruction.present = request.system_instruction.is_some(),
484 cached.content.present = request.cached_content.is_some(),
485 ), err)]
486 pub(crate) async fn generate_content_stream(
487 &self,
488 request: GenerateContentRequest,
489 ) -> Result<backend::BackendStream<GenerationResponse>, Error> {
490 if let Some(ref gc) = request.generation_config {
491 gc.validate().map_err(|message| Error::InvalidGenerationConfig { message })?;
492 }
493
494 self.backend.generate_content_stream(request).await
495 }
496
497 #[instrument(skip_all, fields(
498 model,
499 task.type = request.task_type.as_ref().map(|t| format!("{:?}", t)),
500 task.title = request.title,
501 task.output.dimensionality = request.output_dimensionality,
502 ))]
503 pub(crate) async fn embed_content(
504 &self,
505 request: EmbedContentRequest,
506 ) -> Result<ContentEmbeddingResponse, Error> {
507 self.backend.embed_content(request).await
508 }
509
510 #[instrument(skip_all, fields(batch.size = request.requests.len()))]
511 pub(crate) async fn embed_content_batch(
512 &self,
513 request: BatchEmbedContentsRequest,
514 ) -> Result<BatchContentEmbeddingResponse, Error> {
515 self.backend.batch_embed_contents(request).await
516 }
517
518 #[instrument(skip_all, fields(
519 batch.display_name = request.batch.display_name,
520 batch.size = request.batch.input_config.batch_size(),
521 ))]
522 pub(crate) async fn batch_generate_content(
523 &self,
524 request: BatchGenerateContentRequest,
525 ) -> Result<BatchGenerateContentResponse, Error> {
526 self.backend.batch_generate_content(request).await
527 }
528
529 #[instrument(skip_all, fields(operation.name = name))]
530 pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
531 &self,
532 name: &str,
533 ) -> Result<T, Error> {
534 let value = self.backend.get_batch_operation(name).await?;
535 serde_json::from_value(value).context(DeserializeSnafu)
536 }
537
538 #[instrument(skip_all, fields(page.size = page_size, page.token.present = page_token.is_some()))]
539 pub(crate) async fn list_batch_operations(
540 &self,
541 page_size: Option<u32>,
542 page_token: Option<String>,
543 ) -> Result<ListBatchesResponse, Error> {
544 self.backend.list_batch_operations(page_size, page_token).await
545 }
546
547 #[instrument(skip_all, fields(page.size = page_size, page.token.present = page_token.is_some()))]
548 pub(crate) async fn list_files(
549 &self,
550 page_size: Option<u32>,
551 page_token: Option<String>,
552 ) -> Result<ListFilesResponse, Error> {
553 self.backend.list_files(page_size, page_token).await
554 }
555
556 #[instrument(skip_all, fields(operation.name = name))]
557 pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<(), Error> {
558 self.backend.cancel_batch_operation(name).await
559 }
560
561 #[instrument(skip_all, fields(operation.name = name))]
562 pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<(), Error> {
563 self.backend.delete_batch_operation(name).await
564 }
565
566 #[instrument(skip_all, fields(
567 file.size = file_bytes.len(),
568 mime.type = mime_type.to_string(),
569 file.display_name = display_name.as_deref(),
570 ))]
571 pub(crate) async fn upload_file(
572 &self,
573 display_name: Option<String>,
574 file_bytes: Vec<u8>,
575 mime_type: Mime,
576 ) -> Result<File, Error> {
577 self.backend.upload_file(display_name, file_bytes, mime_type).await
578 }
579
580 #[instrument(skip_all, fields(file.name = name))]
581 pub(crate) async fn get_file(&self, name: &str) -> Result<File, Error> {
582 self.backend.get_file(name).await
583 }
584
585 #[instrument(skip_all, fields(file.name = name))]
586 pub(crate) async fn delete_file(&self, name: &str) -> Result<(), Error> {
587 self.backend.delete_file(name).await
588 }
589
590 #[instrument(skip_all, fields(file.name = name))]
591 pub(crate) async fn download_file(&self, name: &str) -> Result<Vec<u8>, Error> {
592 self.backend.download_file(name).await
593 }
594
595 pub(crate) async fn create_cached_content(
596 &self,
597 cached_content: CreateCachedContentRequest,
598 ) -> Result<CachedContent, Error> {
599 self.backend.create_cached_content(cached_content).await
600 }
601
602 pub(crate) async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
603 self.backend.get_cached_content(name).await
604 }
605
606 pub(crate) async fn update_cached_content(
607 &self,
608 name: &str,
609 expiration: CacheExpirationRequest,
610 ) -> Result<CachedContent, Error> {
611 self.backend.update_cached_content(name, expiration).await
612 }
613
614 pub(crate) async fn delete_cached_content(&self, name: &str) -> Result<(), Error> {
615 self.backend.delete_cached_content(name).await
616 }
617
618 pub(crate) async fn list_cached_contents(
619 &self,
620 page_size: Option<i32>,
621 page_token: Option<String>,
622 ) -> Result<ListCachedContentsResponse, Error> {
623 self.backend.list_cached_contents(page_size, page_token).await
624 }
625
626 #[instrument(skip_all, fields(page.size = page_size, page.token.present = page_token.is_some()))]
629 pub(crate) async fn list_models(
630 &self,
631 page_size: Option<u32>,
632 page_token: Option<String>,
633 ) -> Result<crate::model_info::ListModelsResponse, Error> {
634 self.backend.list_models(page_size, page_token).await
635 }
636
637 #[instrument(skip_all, fields(model.name = name))]
638 pub(crate) async fn get_model(
639 &self,
640 name: &str,
641 ) -> Result<crate::model_info::ModelInfo, Error> {
642 self.backend.get_model(name).await
643 }
644}
645
/// How the builder authenticates against Google Cloud.
#[cfg(feature = "vertex")]
#[derive(Debug, Clone)]
enum GoogleCloudAuth {
    /// Authenticate with a plain API key.
    ApiKey(String),
    /// Authenticate with already-built credentials.
    Credentials(Credentials),
}
656
#[cfg(feature = "vertex")]
impl GoogleCloudAuth {
    /// Produces `Credentials` for this auth mode.
    ///
    /// An API key is wrapped in API-key credentials; existing credentials are
    /// cloned. Neither path can fail in the current implementation, so this
    /// always returns `Ok`.
    fn credentials(&self) -> Result<Credentials, Error> {
        let resolved = match self {
            Self::ApiKey(key) => credentials::api_key_credentials::Builder::new(key).build(),
            Self::Credentials(existing) => existing.clone(),
        };
        Ok(resolved)
    }
}
668
/// Project and location that select the Vertex AI endpoint and model path.
#[cfg(feature = "vertex")]
#[derive(Debug, Clone)]
struct GoogleCloudConfig {
    /// Google Cloud project identifier.
    project_id: String,
    /// Location name; the value `"global"` selects the non-regional endpoint.
    location: String,
}
675
#[cfg(feature = "vertex")]
impl GoogleCloudConfig {
    /// Returns the Vertex AI endpoint URL for the configured location.
    ///
    /// `"global"` maps to the non-regional host; any other location is used
    /// as a `{location}-` host prefix.
    fn endpoint(&self) -> String {
        match self.location.as_str() {
            "global" => "https://aiplatform.googleapis.com".to_string(),
            region => format!("https://{region}-aiplatform.googleapis.com"),
        }
    }
}
686
/// Reads the `project_id` string from service-account JSON.
///
/// The value is trimmed of surrounding whitespace before being returned.
///
/// # Errors
/// * [`Error::GoogleCloudCredentialParse`] when the input is not valid JSON.
/// * [`Error::MissingGoogleCloudProjectId`] when `project_id` is absent, not
///   a string, or blank after trimming.
#[cfg(feature = "vertex")]
fn extract_service_account_project_id(service_account_json: &str) -> Result<String, Error> {
    let parsed = serde_json::from_str::<serde_json::Value>(service_account_json)
        .context(GoogleCloudCredentialParseSnafu)?;

    let trimmed = parsed.get("project_id").and_then(serde_json::Value::as_str).map(str::trim);

    match trimmed {
        Some(id) if !id.is_empty() => Ok(id.to_string()),
        _ => Err(Error::MissingGoogleCloudProjectId),
    }
}
701
/// Synchronously builds a `PredictionService` client for `endpoint`.
///
/// The client's async `build()` is driven on a freshly created Tokio runtime.
/// When the caller is already inside a Tokio runtime, that work is moved to a
/// dedicated thread named `adk-gemini-vertex-init` and joined, instead of
/// calling `block_on` on the current thread.
///
/// # Errors
/// * [`Error::TokioRuntime`] when the runtime or the helper thread cannot be
///   created.
/// * [`Error::GoogleCloudClientBuild`] when the client fails to build.
/// * [`Error::GoogleCloudInitThreadPanicked`] when the helper thread panics.
#[cfg(feature = "vertex")]
fn build_vertex_prediction_service(
    endpoint: String,
    credentials: Credentials,
) -> Result<PredictionService, Error> {
    /// Creates a runtime and blocks on the client build.
    fn build_blocking(
        endpoint: String,
        credentials: Credentials,
    ) -> Result<PredictionService, Error> {
        let runtime = tokio::runtime::Runtime::new().context(TokioRuntimeSnafu)?;
        let pending = PredictionService::builder()
            .with_endpoint(endpoint)
            .with_credentials(credentials)
            .build();
        runtime.block_on(pending).context(GoogleCloudClientBuildSnafu)
    }

    // No ambient runtime: safe to block right here.
    if tokio::runtime::Handle::try_current().is_err() {
        return build_blocking(endpoint, credentials);
    }

    let worker = std::thread::Builder::new()
        .name("adk-gemini-vertex-init".to_string())
        .spawn(move || build_blocking(endpoint, credentials))
        .map_err(|source| Error::TokioRuntime { source })?;

    match worker.join() {
        Ok(outcome) => outcome,
        Err(_) => Err(Error::GoogleCloudInitThreadPanicked),
    }
}
731
/// Builder that collects configuration and produces a [`Gemini`] client.
pub struct GeminiBuilder {
    /// Model the resulting client will use.
    model: Model,
    /// HTTP client builder supplied through `with_http_client`.
    client_builder: ClientBuilder,
    /// Base URL handed to the Studio backend.
    base_url: Url,
    /// Project/location; when set, `build` creates a Vertex-backed client.
    #[cfg(feature = "vertex")]
    google_cloud: Option<GoogleCloudConfig>,
    /// API key, if any.
    api_key: Option<String>,
    /// Explicit Google Cloud credentials, if any.
    #[cfg(feature = "vertex")]
    google_cloud_auth: Option<GoogleCloudAuth>,
}
762
763impl GeminiBuilder {
764 pub fn new<K: Into<String>>(key: K) -> Self {
765 Self {
766 model: Model::default(),
767 client_builder: ClientBuilder::default(),
768 base_url: DEFAULT_BASE_URL.clone(),
769 #[cfg(feature = "vertex")]
770 google_cloud: None,
771 api_key: Some(key.into()),
772 #[cfg(feature = "vertex")]
773 google_cloud_auth: None,
774 }
775 }
776
777 pub fn with_model<M: Into<Model>>(mut self, model: M) -> Self {
778 self.model = model.into();
779 self
780 }
781
782 pub fn with_http_client(mut self, client_builder: ClientBuilder) -> Self {
783 self.client_builder = client_builder;
784 self
785 }
786
787 pub fn with_base_url(mut self, base_url: Url) -> Self {
788 self.base_url = base_url;
789 #[cfg(feature = "vertex")]
790 {
791 self.google_cloud = None;
792 self.google_cloud_auth = None;
793 }
794 self
795 }
796
797 #[cfg(feature = "vertex")]
798 pub fn with_service_account_json(mut self, service_account_json: &str) -> Result<Self, Error> {
799 let value =
800 serde_json::from_str(service_account_json).context(GoogleCloudCredentialParseSnafu)?;
801 let credentials = google_cloud_auth::credentials::service_account::Builder::new(value)
802 .build()
803 .context(GoogleCloudAuthSnafu)?;
804 self.google_cloud_auth = Some(GoogleCloudAuth::Credentials(credentials));
805 Ok(self)
806 }
807
808 #[cfg(feature = "vertex")]
809 pub fn with_google_cloud<P: Into<String>, L: Into<String>>(
810 mut self,
811 project_id: P,
812 location: L,
813 ) -> Self {
814 self.google_cloud =
815 Some(GoogleCloudConfig { project_id: project_id.into(), location: location.into() });
816 self
817 }
818
819 #[cfg(feature = "vertex")]
820 pub fn with_google_cloud_adc(mut self) -> Result<Self, Error> {
821 let credentials = google_cloud_auth::credentials::Builder::default()
822 .build()
823 .context(GoogleCloudAuthSnafu)?;
824 self.google_cloud_auth = Some(GoogleCloudAuth::Credentials(credentials));
825 Ok(self)
826 }
827
828 #[cfg(feature = "vertex")]
829 pub fn with_google_cloud_wif_json(mut self, wif_json: &str) -> Result<Self, Error> {
830 let value = serde_json::from_str(wif_json).context(GoogleCloudCredentialParseSnafu)?;
831 let credentials = google_cloud_auth::credentials::external_account::Builder::new(value)
832 .build()
833 .context(GoogleCloudAuthSnafu)?;
834 self.google_cloud_auth = Some(GoogleCloudAuth::Credentials(credentials));
835 Ok(self)
836 }
837
838 pub fn build(self) -> Result<Gemini, Error> {
840 #[cfg(feature = "vertex")]
841 {
842 if self.google_cloud.is_none() && self.google_cloud_auth.is_some() {
843 return MissingGoogleCloudConfigSnafu.fail();
844 }
845
846 if let Some(config) = &self.google_cloud {
848 let model = Model::Custom(
849 self.model.vertex_model_path(&config.project_id, &config.location),
850 );
851 let google_cloud_auth = match self.google_cloud_auth {
852 Some(auth) => auth,
853 None => match self.api_key {
854 Some(api_key) if !api_key.is_empty() => GoogleCloudAuth::ApiKey(api_key),
855 _ => return MissingGoogleCloudAuthSnafu.fail(),
856 },
857 };
858 let credentials = google_cloud_auth.credentials()?;
859 let endpoint = config.endpoint();
860 let prediction =
861 build_vertex_prediction_service(endpoint.clone(), credentials.clone())?;
862
863 let vertex = backend::vertex::VertexBackend::new(
864 model.clone(),
865 prediction,
866 credentials,
867 endpoint,
868 );
869
870 return Ok(Gemini { client: Arc::new(GeminiClient::with_vertex(model, vertex)) });
871 }
872 }
873
874 let api_key = self.api_key.ok_or(Error::MissingApiKey)?;
876 if api_key.is_empty() {
877 return MissingApiKeySnafu.fail();
878 }
879
880 let studio =
881 backend::studio::StudioBackend::new(&api_key, self.model.clone(), self.base_url)?;
882
883 Ok(Gemini { client: Arc::new(GeminiClient::with_studio(self.model, studio)) })
884 }
885}
886
/// High-level client handle.
///
/// Holds the underlying [`GeminiClient`] behind an `Arc`, which is cloned into
/// the builders and handles this type hands out.
pub struct Gemini {
    /// Shared low-level client.
    client: Arc<GeminiClient>,
}
894
895impl Gemini {
896 pub fn new<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
898 Self::with_model(api_key, Model::default())
899 }
900
901 pub fn pro<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
903 Self::with_model(api_key, Model::Gemini31ProPreview)
904 }
905
906 pub fn with_model<K: AsRef<str>, M: Into<Model>>(api_key: K, model: M) -> Result<Self, Error> {
908 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.clone())
909 }
910
911 pub fn with_v1<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
913 Self::with_model_and_base_url(api_key, Model::default(), V1_BASE_URL.clone())
914 }
915
916 pub fn with_model_v1<K: AsRef<str>, M: Into<Model>>(
918 api_key: K,
919 model: M,
920 ) -> Result<Self, Error> {
921 Self::with_model_and_base_url(api_key, model, V1_BASE_URL.clone())
922 }
923
924 pub fn with_base_url<K: AsRef<str>>(api_key: K, base_url: Url) -> Result<Self, Error> {
926 Self::with_model_and_base_url(api_key, Model::default(), base_url)
927 }
928
/// Creates a Google Cloud client for the default model, authenticated with
/// an API key.
#[cfg(feature = "vertex")]
pub fn with_google_cloud<K: AsRef<str>, P: AsRef<str>, L: AsRef<str>>(
    api_key: K,
    project_id: P,
    location: L,
) -> Result<Self, Error> {
    let model = Model::default();
    Self::with_google_cloud_model(api_key, project_id, location, model)
}
938
/// Creates a Google Cloud client for `model`, authenticated with an API key.
#[cfg(feature = "vertex")]
pub fn with_google_cloud_model<K: AsRef<str>, P: AsRef<str>, L: AsRef<str>, M: Into<Model>>(
    api_key: K,
    project_id: P,
    location: L,
    model: M,
) -> Result<Self, Error> {
    let builder = GeminiBuilder::new(api_key.as_ref())
        .with_model(model)
        .with_google_cloud(project_id.as_ref(), location.as_ref());
    builder.build()
}
952
/// Creates a Google Cloud client for the default model using the default
/// Google Cloud credentials.
#[cfg(feature = "vertex")]
pub fn with_google_cloud_adc<P: AsRef<str>, L: AsRef<str>>(
    project_id: P,
    location: L,
) -> Result<Self, Error> {
    let model = Model::default();
    Self::with_google_cloud_adc_model(project_id, location, model)
}
961
/// Creates a Google Cloud client for `model` using the default Google Cloud
/// credentials.
///
/// The builder is started with an empty API key; authentication comes from
/// the credentials only.
#[cfg(feature = "vertex")]
pub fn with_google_cloud_adc_model<P: AsRef<str>, L: AsRef<str>, M: Into<Model>>(
    project_id: P,
    location: L,
    model: M,
) -> Result<Self, Error> {
    let configured = GeminiBuilder::new("")
        .with_model(model)
        .with_google_cloud(project_id.as_ref(), location.as_ref());
    let authenticated = configured.with_google_cloud_adc()?;
    authenticated.build()
}
975
/// Creates a Google Cloud client for `model`, authenticated with
/// external-account (workload identity federation) JSON.
///
/// The builder is started with an empty API key; authentication comes from
/// the credentials only.
#[cfg(feature = "vertex")]
pub fn with_google_cloud_wif_json<P: AsRef<str>, L: AsRef<str>, M: Into<Model>>(
    wif_json: &str,
    project_id: P,
    location: L,
    model: M,
) -> Result<Self, Error> {
    let configured = GeminiBuilder::new("")
        .with_model(model)
        .with_google_cloud(project_id.as_ref(), location.as_ref());
    let authenticated = configured.with_google_cloud_wif_json(wif_json)?;
    authenticated.build()
}
990
/// Creates a Google Cloud client for the default model from
/// service-account JSON.
#[cfg(feature = "vertex")]
pub fn with_service_account_json(service_account_json: &str) -> Result<Self, Error> {
    let model = Model::default();
    Self::with_service_account_json_model(service_account_json, model)
}
996
/// Creates a Google Cloud client for `model` from service-account JSON.
///
/// The project is taken from the JSON's `project_id` field and the location
/// is fixed to `us-central1`.
///
/// # Errors
/// Fails before any builder work if the project id cannot be extracted from
/// the JSON; otherwise propagates builder errors.
#[cfg(feature = "vertex")]
pub fn with_service_account_json_model<M: Into<Model>>(
    service_account_json: &str,
    model: M,
) -> Result<Self, Error> {
    let project_id = extract_service_account_project_id(service_account_json)?;
    let authenticated = GeminiBuilder::new("")
        .with_model(model)
        .with_service_account_json(service_account_json)?;
    authenticated.with_google_cloud(project_id, "us-central1").build()
}
1010
/// Creates a Google Cloud client for `model` from service-account JSON with
/// an explicit project and location.
#[cfg(feature = "vertex")]
pub fn with_google_cloud_service_account_json<M: Into<Model>>(
    service_account_json: &str,
    project_id: &str,
    location: &str,
    model: M,
) -> Result<Self, Error> {
    let authenticated = GeminiBuilder::new("")
        .with_model(model)
        .with_service_account_json(service_account_json)?;
    authenticated.with_google_cloud(project_id, location).build()
}
1025
1026 pub fn with_model_and_base_url<K: AsRef<str>, M: Into<Model>>(
1028 api_key: K,
1029 model: M,
1030 base_url: Url,
1031 ) -> Result<Self, Error> {
1032 let model = model.into();
1033 let studio =
1034 backend::studio::StudioBackend::new(api_key.as_ref(), model.clone(), base_url)?;
1035 Ok(Self { client: Arc::new(GeminiClient::with_studio(model, studio)) })
1036 }
1037
1038 pub fn generate_content(&self) -> ContentBuilder {
1040 ContentBuilder::new(self.client.clone())
1041 }
1042
1043 pub fn embed_content(&self) -> EmbedBuilder {
1045 EmbedBuilder::new(self.client.clone())
1046 }
1047
1048 pub fn batch_generate_content(&self) -> BatchBuilder {
1050 BatchBuilder::new(self.client.clone())
1051 }
1052
1053 pub fn get_batch(&self, name: &str) -> BatchHandle {
1055 BatchHandle::new(name.to_string(), self.client.clone())
1056 }
1057
1058 pub fn list_batches(
1060 &self,
1061 page_size: impl Into<Option<u32>>,
1062 ) -> impl Stream<Item = Result<BatchOperation, Error>> + Send {
1063 let client = self.client.clone();
1064 let page_size = page_size.into();
1065 async_stream::try_stream! {
1066 let mut page_token: Option<String> = None;
1067 loop {
1068 let response = client
1069 .list_batch_operations(page_size, page_token.clone())
1070 .await?;
1071
1072 for operation in response.operations {
1073 yield operation;
1074 }
1075
1076 if let Some(next_page_token) = response.next_page_token {
1077 page_token = Some(next_page_token);
1078 } else {
1079 break;
1080 }
1081 }
1082 }
1083 }
1084
1085 pub fn create_cache(&self) -> CacheBuilder {
1087 CacheBuilder::new(self.client.clone())
1088 }
1089
1090 pub fn get_cached_content(&self, name: &str) -> CachedContentHandle {
1092 CachedContentHandle::new(name.to_string(), self.client.clone())
1093 }
1094
1095 pub fn list_cached_contents(
1097 &self,
1098 page_size: impl Into<Option<i32>>,
1099 ) -> impl Stream<Item = Result<CachedContentSummary, Error>> + Send {
1100 let client = self.client.clone();
1101 let page_size = page_size.into();
1102 async_stream::try_stream! {
1103 let mut page_token: Option<String> = None;
1104 loop {
1105 let response = client
1106 .list_cached_contents(page_size, page_token.clone())
1107 .await?;
1108
1109 for cached_content in response.cached_contents {
1110 yield cached_content;
1111 }
1112
1113 if let Some(next_page_token) = response.next_page_token {
1114 page_token = Some(next_page_token);
1115 } else {
1116 break;
1117 }
1118 }
1119 }
1120 }
1121
1122 pub fn create_file<B: Into<Vec<u8>>>(&self, bytes: B) -> crate::files::builder::FileBuilder {
1124 crate::files::builder::FileBuilder::new(self.client.clone(), bytes)
1125 }
1126
1127 pub async fn get_file(&self, name: &str) -> Result<FileHandle, Error> {
1129 let file = self.client.get_file(name).await?;
1130 Ok(FileHandle::new(self.client.clone(), file))
1131 }
1132
1133 pub fn list_files(
1135 &self,
1136 page_size: impl Into<Option<u32>>,
1137 ) -> impl Stream<Item = Result<FileHandle, Error>> + Send {
1138 let client = self.client.clone();
1139 let page_size = page_size.into();
1140 async_stream::try_stream! {
1141 let mut page_token: Option<String> = None;
1142 loop {
1143 let response = client
1144 .list_files(page_size, page_token.clone())
1145 .await?;
1146
1147 for file in response.files {
1148 yield FileHandle::new(client.clone(), file);
1149 }
1150
1151 if let Some(next_page_token) = response.next_page_token {
1152 page_token = Some(next_page_token);
1153 } else {
1154 break;
1155 }
1156 }
1157 }
1158 }
1159
1160 pub fn list_models(
1181 &self,
1182 page_size: impl Into<Option<u32>>,
1183 ) -> impl Stream<Item = Result<crate::model_info::ModelInfo, Error>> + Send {
1184 let client = self.client.clone();
1185 let page_size = page_size.into();
1186 async_stream::try_stream! {
1187 let mut page_token: Option<String> = None;
1188 loop {
1189 let response = client
1190 .list_models(page_size, page_token.clone())
1191 .await?;
1192
1193 for model in response.models {
1194 yield model;
1195 }
1196
1197 if let Some(next_page_token) = response.next_page_token {
1198 page_token = Some(next_page_token);
1199 } else {
1200 break;
1201 }
1202 }
1203 }
1204 }
1205
1206 pub async fn get_model(&self, name: &str) -> Result<crate::model_info::ModelInfo, Error> {
1220 self.client.get_model(name).await
1221 }
1222}
1223
#[cfg(test)]
#[cfg(feature = "vertex")]
mod client_tests {
    //! Unit tests for service-account parsing, Vertex transport-error detection and
    //! Google Cloud endpoint formatting.

    use super::{Error, GoogleCloudConfig, extract_service_account_project_id};
    use crate::backend::vertex::VertexBackend;

    /// Builds a config for the fixed test project in the given `location`.
    fn config_in(location: &str) -> GoogleCloudConfig {
        GoogleCloudConfig { project_id: "my-project".to_string(), location: location.to_string() }
    }

    #[test]
    fn extract_service_account_project_id_reads_project_id() {
        let key_json = r#"{
            "type": "service_account",
            "project_id": "test-project-123",
            "private_key_id": "key-id"
        }"#;

        let parsed = extract_service_account_project_id(key_json).expect("project id should parse");

        assert_eq!(parsed, "test-project-123");
    }

    #[test]
    fn extract_service_account_project_id_missing_field_errors() {
        let key_json = r#"{
            "type": "service_account",
            "private_key_id": "key-id"
        }"#;

        let failure = extract_service_account_project_id(key_json)
            .expect_err("missing project_id should fail");

        assert!(matches!(failure, Error::MissingGoogleCloudProjectId));
    }

    #[test]
    fn extract_service_account_project_id_invalid_json_errors() {
        let failure = extract_service_account_project_id("not-json")
            .expect_err("invalid json should fail");

        assert!(matches!(failure, Error::GoogleCloudCredentialParse { .. }));
    }

    #[test]
    fn vertex_transport_error_detection_matches_http2_failure() {
        let http2_failure =
            "the transport reports an error: client error (SendRequest): http2 error";
        let unrelated_failure = "permission denied";

        assert!(VertexBackend::is_transport_error(http2_failure));
        assert!(!VertexBackend::is_transport_error(unrelated_failure));
    }

    #[test]
    fn vertex_regional_endpoint_uses_location_prefix() {
        let regional = config_in("us-central1");

        assert_eq!(regional.endpoint(), "https://us-central1-aiplatform.googleapis.com");
    }

    #[test]
    fn vertex_global_endpoint_omits_location_prefix() {
        let global = config_in("global");

        assert_eq!(global.endpoint(), "https://aiplatform.googleapis.com");
    }

    #[test]
    fn vertex_other_regional_endpoint_formats_correctly() {
        let regional = config_in("europe-west4");

        assert_eq!(regional.endpoint(), "https://europe-west4-aiplatform.googleapis.com");
    }
}