1use crate::{
2 batch::{BatchBuilder, BatchHandle},
3 cache::{CacheBuilder, CachedContentHandle},
4 embedding::{
5 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6 EmbedBuilder, EmbedContentRequest,
7 },
8 files::{
9 handle::FileHandle,
10 model::{File, ListFilesResponse},
11 },
12 generation::{ContentBuilder, GenerateContentRequest, GenerationResponse},
13};
14use eventsource_stream::{EventStreamError, Eventsource};
15use futures::{Stream, StreamExt, TryStreamExt};
16use mime::Mime;
17use reqwest::{
18 header::{HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue},
19 Client, ClientBuilder, Response,
20};
21use serde::de::DeserializeOwned;
22use serde_json::json;
23use snafu::{ResultExt, Snafu};
24use std::{
25 fmt::{self, Formatter},
26 sync::{Arc, LazyLock},
27};
28use url::Url;
29
30use crate::batch::model::*;
31use crate::cache::model::*;
32
33static DEFAULT_BASE_URL: LazyLock<Url> = LazyLock::new(|| {
34 Url::parse("https://generativelanguage.googleapis.com/v1beta/")
35 .expect("unreachable error: failed to parse default base URL")
36});
37
/// A Gemini model, identified by the resource name used in API paths.
///
/// The predefined variants map to fixed `models/...` names; [`Model::Custom`] carries any
/// other resource name verbatim.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum Model {
    /// `models/gemini-2.5-flash` (the default).
    #[default]
    Gemini25Flash,
    /// `models/gemini-2.5-flash-lite`.
    Gemini25FlashLite,
    /// `models/gemini-2.5-pro`.
    Gemini25Pro,
    /// `models/text-embedding-004`.
    TextEmbedding004,
    /// Any other model resource name, returned unchanged by [`Model::as_str`].
    Custom(String),
}

impl Model {
    /// Returns the resource name of this model, e.g. `"models/gemini-2.5-flash"`.
    ///
    /// This is the single source of truth for model names; `Display` delegates to it.
    pub fn as_str(&self) -> &str {
        match self {
            Model::Gemini25Flash => "models/gemini-2.5-flash",
            Model::Gemini25FlashLite => "models/gemini-2.5-flash-lite",
            Model::Gemini25Pro => "models/gemini-2.5-pro",
            Model::TextEmbedding004 => "models/text-embedding-004",
            Model::Custom(model) => model,
        }
    }
}

impl From<String> for Model {
    /// Wraps an arbitrary model name as [`Model::Custom`].
    fn from(model: String) -> Self {
        Self::Custom(model)
    }
}

impl fmt::Display for Model {
    /// Writes exactly the string returned by [`Model::as_str`].
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Delegate instead of repeating the name table, so the two can never drift apart.
        f.write_str(self.as_str())
    }
}
77
78#[derive(Debug, Snafu)]
79#[snafu(visibility(pub))]
80pub enum Error {
81 #[snafu(display("failed to parse API key"))]
82 InvalidApiKey { source: InvalidHeaderValue },
83
84 #[snafu(display("failed to construct URL (probably incorrect model name): {suffix}"))]
85 ConstructUrl {
86 source: url::ParseError,
87 suffix: String,
88 },
89
90 #[snafu(display("failed to perform request to '{url}'"))]
91 PerformRequest { source: reqwest::Error, url: Url },
92
93 #[snafu(display(
94 "bad response from server; code {code}; description: {}",
95 description.as_deref().unwrap_or("none")
96 ))]
97 BadResponse {
98 code: u16,
100 description: Option<String>,
102 },
103
104 #[snafu(display("failed to obtain stream SSE part"))]
105 BadPart {
106 source: EventStreamError<reqwest::Error>,
107 },
108
109 #[snafu(display("failed to deserialize JSON response"))]
110 Deserialize { source: serde_json::Error },
111
112 #[snafu(display("failed to generate content"))]
113 DecodeResponse { source: reqwest::Error },
114
115 #[snafu(display("failed to parse URL"))]
116 UrlParse { source: url::ParseError },
117
118 #[snafu(display("I/O error during file operations"))]
119 Io { source: std::io::Error },
120}
121
122pub(crate) struct GeminiClient {
124 http_client: Client,
125 pub model: Model,
126 base_url: Url,
127}
128
129impl GeminiClient {
130 fn with_base_url<K: AsRef<str>, M: Into<Model>>(
132 api_key: K,
133 model: M,
134 base_url: Url,
135 ) -> Result<Self, Error> {
136 let headers = HeaderMap::from_iter([(
137 HeaderName::from_static("x-goog-api-key"),
138 HeaderValue::from_str(api_key.as_ref()).context(InvalidApiKeySnafu)?,
139 )]);
140
141 let http_client = ClientBuilder::new()
142 .default_headers(headers)
143 .build()
144 .expect("all parameters must be valid");
145
146 Ok(Self {
147 http_client,
148 model: model.into(),
149 base_url,
150 })
151 }
152
153 async fn check_response(response: Response) -> Result<Response, Error> {
155 let status = response.status();
156 if !status.is_success() {
157 let description = response.text().await.ok();
158 BadResponseSnafu {
159 code: status.as_u16(),
160 description,
161 }
162 .fail()
163 } else {
164 Ok(response)
165 }
166 }
167
168 pub(crate) async fn generate_content_raw(
170 &self,
171 request: GenerateContentRequest,
172 ) -> Result<GenerationResponse, Error> {
173 let url = self.build_url("generateContent")?;
174
175 let response = self
176 .http_client
177 .post(url.clone())
178 .json(&request)
179 .send()
180 .await
181 .context(PerformRequestSnafu { url })?;
182
183 Self::check_response(response)
184 .await?
185 .json()
186 .await
187 .context(DecodeResponseSnafu)
188 }
189
190 pub(crate) async fn generate_content_stream(
192 &self,
193 request: GenerateContentRequest,
194 ) -> Result<impl TryStreamExt<Ok = GenerationResponse, Error = Error> + Send, Error> {
195 let mut url = self.build_url("streamGenerateContent")?;
196 url.query_pairs_mut().append_pair("alt", "sse");
197
198 let response = self
199 .http_client
200 .post(url.clone())
201 .json(&request)
202 .send()
203 .await
204 .context(PerformRequestSnafu { url })?;
205
206 Ok(Self::check_response(response)
207 .await?
208 .bytes_stream()
209 .eventsource()
210 .map(|event| event.context(BadPartSnafu))
211 .map_ok(|event| {
212 serde_json::from_str::<GenerationResponse>(&event.data).context(DeserializeSnafu)
213 })
214 .map(|r| r.flatten()))
215 }
216
217 pub(crate) async fn embed_content(
219 &self,
220 request: EmbedContentRequest,
221 ) -> Result<ContentEmbeddingResponse, Error> {
222 self.post_json(request, "embedContent").await
223 }
224
225 pub(crate) async fn embed_content_batch(
227 &self,
228 request: BatchEmbedContentsRequest,
229 ) -> Result<BatchContentEmbeddingResponse, Error> {
230 self.post_json(request, "batchEmbedContents").await
231 }
232
233 pub(crate) async fn batch_generate_content_sync(
235 &self,
236 request: BatchGenerateContentRequest,
237 ) -> Result<BatchGenerateContentResponse, Error> {
238 let value = self.post_json(request, "batchGenerateContent").await?;
239 serde_json::from_value(value).context(DeserializeSnafu)
240 }
241
242 pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
244 &self,
245 name: &str,
246 ) -> Result<T, Error> {
247 let url = self.build_batch_url(name, None)?;
248
249 let response = self
250 .http_client
251 .get(url.clone())
252 .send()
253 .await
254 .context(PerformRequestSnafu { url })?;
255
256 Self::check_response(response)
257 .await?
258 .json()
259 .await
260 .context(DecodeResponseSnafu)
261 }
262
263 pub(crate) async fn list_batch_operations(
265 &self,
266 page_size: Option<u32>,
267 page_token: Option<String>,
268 ) -> Result<ListBatchesResponse, Error> {
269 let mut url = self.build_batch_url("batches", None)?;
270
271 if let Some(size) = page_size {
272 url.query_pairs_mut()
273 .append_pair("pageSize", &size.to_string());
274 }
275 if let Some(token) = page_token {
276 url.query_pairs_mut().append_pair("pageToken", &token);
277 }
278
279 let response = self
280 .http_client
281 .get(url.clone())
282 .send()
283 .await
284 .context(PerformRequestSnafu { url })?;
285
286 Self::check_response(response)
287 .await?
288 .json()
289 .await
290 .context(DecodeResponseSnafu)
291 }
292
293 pub(crate) async fn list_files(
295 &self,
296 page_size: Option<u32>,
297 page_token: Option<String>,
298 ) -> Result<ListFilesResponse, Error> {
299 let mut url = self.build_files_url(None)?;
300
301 if let Some(size) = page_size {
302 url.query_pairs_mut()
303 .append_pair("pageSize", &size.to_string());
304 }
305 if let Some(token) = page_token {
306 url.query_pairs_mut().append_pair("pageToken", &token);
307 }
308
309 let response = self
310 .http_client
311 .get(url.clone())
312 .send()
313 .await
314 .context(PerformRequestSnafu { url })?;
315
316 Self::check_response(response)
317 .await?
318 .json()
319 .await
320 .context(DecodeResponseSnafu)
321 }
322
323 pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<(), Error> {
325 let url = self.build_batch_url(name, Some("cancel"))?;
326 let response = self
327 .http_client
328 .post(url.clone())
329 .json(&json!({}))
330 .send()
331 .await
332 .context(PerformRequestSnafu { url })?;
333
334 Self::check_response(response).await?;
335 Ok(())
336 }
337
338 pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<(), Error> {
340 let url = self.build_batch_url(name, None)?;
341 let response = self
342 .http_client
343 .delete(url.clone())
344 .send()
345 .await
346 .context(PerformRequestSnafu { url })?;
347
348 Self::check_response(response).await?;
349 Ok(())
350 }
351
352 pub(crate) async fn upload_file(
354 &self,
355 display_name: Option<String>,
356 file_bytes: Vec<u8>,
357 mime_type: Mime,
358 ) -> Result<File, Error> {
359 let initiate_url =
362 self.base_url
363 .join("/upload/v1beta/files")
364 .context(ConstructUrlSnafu {
365 suffix: "/upload/v1beta/files".to_string(),
366 })?;
367
368 let response = self
369 .http_client
370 .post(initiate_url.clone())
371 .header("X-Goog-Upload-Protocol", "resumable")
372 .header("X-Goog-Upload-Command", "start")
373 .header(
374 "X-Goog-Upload-Header-Content-Length",
375 file_bytes.len().to_string(),
376 )
377 .header("X-Goog-Upload-Header-Content-Type", mime_type.to_string())
378 .json(&json!({"file": {"displayName": display_name}}))
379 .send()
380 .await
381 .context(PerformRequestSnafu {
382 url: initiate_url.clone(),
383 })?;
384
385 let checked_response = Self::check_response(response).await?;
386
387 let upload_url = checked_response
388 .headers()
389 .get("X-Goog-Upload-URL")
390 .and_then(|h| h.to_str().ok())
391 .ok_or_else(|| Error::BadResponse {
392 code: 500,
393 description: Some("Missing upload URL in response".to_string()),
394 })?;
395
396 let upload_response = self
398 .http_client
399 .post(upload_url)
400 .header("X-Goog-Upload-Command", "upload, finalize")
401 .header("X-Goog-Upload-Offset", "0")
402 .body(file_bytes)
403 .send()
404 .await
405 .map_err(|e| Error::PerformRequest {
406 source: e,
407 url: Url::parse(upload_url).unwrap_or_else(|_| initiate_url.clone()),
408 })?;
409
410 let final_response = Self::check_response(upload_response).await?;
411
412 #[derive(serde::Deserialize)]
413 struct UploadResponse {
414 file: File,
415 }
416
417 let upload_response: UploadResponse =
418 final_response.json().await.context(DecodeResponseSnafu)?;
419 Ok(upload_response.file)
420 }
421
422 pub(crate) async fn get_file(&self, name: &str) -> Result<File, Error> {
424 let url = self.build_files_url(Some(name))?;
425 let response = self
426 .http_client
427 .get(url.clone())
428 .send()
429 .await
430 .context(PerformRequestSnafu { url })?;
431
432 Self::check_response(response)
433 .await?
434 .json()
435 .await
436 .context(DecodeResponseSnafu)
437 }
438
439 pub(crate) async fn delete_file(&self, name: &str) -> Result<(), Error> {
441 let url = self.build_files_url(Some(name))?;
442 let response = self
443 .http_client
444 .delete(url.clone())
445 .send()
446 .await
447 .context(PerformRequestSnafu { url })?;
448
449 Self::check_response(response).await?;
450 Ok(())
451 }
452
453 pub(crate) async fn download_file(&self, name: &str) -> Result<Vec<u8>, Error> {
454 let mut url = self
455 .base_url
456 .join(&format!("/download/v1beta/{name}:download"))
457 .context(ConstructUrlSnafu {
458 suffix: format!("/download/v1beta/{name}:download"),
459 })?;
460 url.query_pairs_mut().append_pair("alt", "media");
461
462 let response = self
463 .http_client
464 .get(url.clone())
465 .send()
466 .await
467 .context(PerformRequestSnafu { url: url.clone() })?;
468
469 Self::check_response(response)
470 .await?
471 .bytes()
472 .await
473 .context(PerformRequestSnafu { url })
474 .map(|b| b.to_vec())
475 }
476
477 async fn post_json<I: serde::Serialize, O: DeserializeOwned>(
479 &self,
480 request: I,
481 endpoint: &str,
482 ) -> Result<O, Error> {
483 let url = self.build_url(endpoint)?;
484
485 let response = self
486 .http_client
487 .post(url.clone())
488 .json(&request)
489 .send()
490 .await
491 .context(PerformRequestSnafu { url })?;
492
493 Self::check_response(response)
494 .await?
495 .json::<O>()
496 .await
497 .context(DecodeResponseSnafu)
498 }
499
500 pub(crate) async fn create_cached_content(
502 &self,
503 cached_content: CreateCachedContentRequest,
504 ) -> Result<CachedContent, Error> {
505 let url = self.build_cache_url(None)?;
506 let response = self
507 .http_client
508 .post(url.clone())
509 .json(&cached_content)
510 .send()
511 .await
512 .context(PerformRequestSnafu { url })?;
513
514 Self::check_response(response)
515 .await?
516 .json::<CachedContent>()
517 .await
518 .context(DecodeResponseSnafu)
519 }
520
521 pub(crate) async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
523 let url = self.build_cache_url(Some(name))?;
524 let response = self
525 .http_client
526 .get(url.clone())
527 .send()
528 .await
529 .context(PerformRequestSnafu { url })?;
530
531 Self::check_response(response)
532 .await?
533 .json::<CachedContent>()
534 .await
535 .context(DecodeResponseSnafu)
536 }
537
538 pub(crate) async fn update_cached_content(
540 &self,
541 name: &str,
542 expiration: CacheExpirationRequest,
543 ) -> Result<CachedContent, Error> {
544 let url = self.build_cache_url(Some(name))?;
545
546 let update_payload = match expiration {
548 CacheExpirationRequest::Ttl { ttl } => json!({ "ttl": ttl }),
549 CacheExpirationRequest::ExpireTime { expire_time } => {
550 json!({ "expireTime": expire_time.format(&time::format_description::well_known::Rfc3339).unwrap() })
551 }
552 };
553
554 let response = self
555 .http_client
556 .patch(url.clone())
557 .json(&update_payload)
558 .send()
559 .await
560 .context(PerformRequestSnafu { url })?;
561
562 Self::check_response(response)
563 .await?
564 .json::<CachedContent>()
565 .await
566 .context(DecodeResponseSnafu)
567 }
568
569 pub(crate) async fn delete_cached_content(
571 &self,
572 name: &str,
573 ) -> Result<DeleteCachedContentResponse, Error> {
574 let url = self.build_cache_url(Some(name))?;
575 let response = self
576 .http_client
577 .delete(url.clone())
578 .send()
579 .await
580 .context(PerformRequestSnafu { url })?;
581
582 if response.status().is_success() {
584 Ok(DeleteCachedContentResponse {
585 success: Some(true),
586 })
587 } else {
588 Self::check_response(response)
589 .await?
590 .json::<DeleteCachedContentResponse>()
591 .await
592 .context(DecodeResponseSnafu)
593 }
594 }
595
596 pub(crate) async fn list_cached_contents(
598 &self,
599 page_size: Option<i32>,
600 page_token: Option<String>,
601 ) -> Result<ListCachedContentsResponse, Error> {
602 let mut url = self.build_cache_url(None)?;
603
604 if let Some(size) = page_size {
605 url.query_pairs_mut()
606 .append_pair("pageSize", &size.to_string());
607 }
608 if let Some(token) = page_token {
609 url.query_pairs_mut().append_pair("pageToken", &token);
610 }
611
612 let response = self
613 .http_client
614 .get(url.clone())
615 .send()
616 .await
617 .context(PerformRequestSnafu { url })?;
618
619 Self::check_response(response)
620 .await?
621 .json::<ListCachedContentsResponse>()
622 .await
623 .context(DecodeResponseSnafu)
624 }
625
626 fn build_url(&self, endpoint: &str) -> Result<Url, Error> {
628 let url = self.base_url.clone();
629 let suffix = format!("{}:{endpoint}", self.model);
630 url.join(&suffix).context(ConstructUrlSnafu { suffix })
631 }
632
633 fn build_batch_url(&self, name: &str, action: Option<&str>) -> Result<Url, Error> {
635 let suffix = action
636 .map(|a| format!("{name}:{a}"))
637 .unwrap_or_else(|| name.to_string());
638
639 let url = self.base_url.clone();
640 url.join(&suffix).context(ConstructUrlSnafu { suffix })
641 }
642
643 fn build_files_url(&self, name: Option<&str>) -> Result<Url, Error> {
645 let suffix = name
646 .map(|n| format!("files/{}", n.strip_prefix("files/").unwrap_or(n)))
647 .unwrap_or_else(|| "files".to_string());
648
649 self.base_url
650 .join(&suffix)
651 .context(ConstructUrlSnafu { suffix })
652 }
653
654 fn build_cache_url(&self, name: Option<&str>) -> Result<Url, Error> {
656 let suffix = name
657 .map(|n| {
658 if n.starts_with("cachedContents/") {
659 n.to_string()
660 } else {
661 format!("cachedContents/{}", n)
662 }
663 })
664 .unwrap_or_else(|| "cachedContents".to_string());
665
666 self.base_url
667 .join(&suffix)
668 .context(ConstructUrlSnafu { suffix })
669 }
670}
671
672#[derive(Clone)]
674pub struct Gemini {
675 client: Arc<GeminiClient>,
676}
677
678impl Gemini {
679 pub fn new<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
681 Self::with_model(api_key, Model::default())
682 }
683
684 pub fn pro<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
686 Self::with_model(api_key, Model::Gemini25Pro)
687 }
688
689 pub fn with_model<K: AsRef<str>, M: Into<Model>>(api_key: K, model: M) -> Result<Self, Error> {
691 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.clone())
692 }
693
694 pub fn with_base_url<K: AsRef<str>>(api_key: K, base_url: Url) -> Result<Self, Error> {
696 Self::with_model_and_base_url(api_key, Model::default(), base_url)
697 }
698
699 pub fn with_model_and_base_url<K: AsRef<str>, M: Into<Model>>(
701 api_key: K,
702 model: M,
703 base_url: Url,
704 ) -> Result<Self, Error> {
705 let client = GeminiClient::with_base_url(api_key, model.into(), base_url)?;
706 Ok(Self {
707 client: Arc::new(client),
708 })
709 }
710
711 pub fn generate_content(&self) -> ContentBuilder {
713 ContentBuilder::new(self.client.clone())
714 }
715
716 pub fn embed_content(&self) -> EmbedBuilder {
718 EmbedBuilder::new(self.client.clone())
719 }
720
721 pub fn batch_generate_content(&self) -> BatchBuilder {
723 BatchBuilder::new(self.client.clone())
724 }
725
726 pub fn get_batch(&self, name: &str) -> BatchHandle {
728 BatchHandle::new(name.to_string(), self.client.clone())
729 }
730
731 pub fn list_batches(
735 &self,
736 page_size: impl Into<Option<u32>>,
737 ) -> impl Stream<Item = Result<BatchOperation, Error>> + Send {
738 let client = self.client.clone();
739 let page_size = page_size.into();
740 async_stream::try_stream! {
741 let mut page_token: Option<String> = None;
742 loop {
743 let response = client
744 .list_batch_operations(page_size, page_token.clone())
745 .await?;
746
747 for operation in response.operations {
748 yield operation;
749 }
750
751 if let Some(next_page_token) = response.next_page_token {
752 page_token = Some(next_page_token);
753 } else {
754 break;
755 }
756 }
757 }
758 }
759
760 pub fn create_cache(&self) -> CacheBuilder {
762 CacheBuilder::new(self.client.clone())
763 }
764
765 pub fn get_cached_content(&self, name: &str) -> CachedContentHandle {
767 CachedContentHandle::new(name.to_string(), self.client.clone())
768 }
769
770 pub fn list_cached_contents(
774 &self,
775 page_size: impl Into<Option<i32>>,
776 ) -> impl Stream<Item = Result<CachedContentSummary, Error>> + Send {
777 let client = self.client.clone();
778 let page_size = page_size.into();
779 async_stream::try_stream! {
780 let mut page_token: Option<String> = None;
781 loop {
782 let response = client
783 .list_cached_contents(page_size, page_token.clone())
784 .await?;
785
786 for cached_content in response.cached_contents {
787 yield cached_content;
788 }
789
790 if let Some(next_page_token) = response.next_page_token {
791 page_token = Some(next_page_token);
792 } else {
793 break;
794 }
795 }
796 }
797 }
798
799 pub fn create_file<B: Into<Vec<u8>>>(&self, bytes: B) -> crate::files::builder::FileBuilder {
801 crate::files::builder::FileBuilder::new(self.client.clone(), bytes)
802 }
803
804 pub async fn get_file(&self, name: &str) -> Result<FileHandle, Error> {
806 let file = self.client.get_file(name).await?;
807 Ok(FileHandle::new(self.client.clone(), file))
808 }
809
810 pub fn list_files(
814 &self,
815 page_size: impl Into<Option<u32>>,
816 ) -> impl Stream<Item = Result<FileHandle, Error>> + Send {
817 let client = self.client.clone();
818 let page_size = page_size.into();
819 async_stream::try_stream! {
820 let mut page_token: Option<String> = None;
821 loop {
822 let response = client
823 .list_files(page_size, page_token.clone())
824 .await?;
825
826 for file in response.files {
827 yield FileHandle::new(client.clone(), file);
828 }
829
830 if let Some(next_page_token) = response.next_page_token {
831 page_token = Some(next_page_token);
832 } else {
833 break;
834 }
835 }
836 }
837 }
838}