gemini_rust/
client.rs

1use crate::{
2    batch::{BatchBuilder, BatchHandle},
3    cache::{CacheBuilder, CachedContentHandle},
4    embedding::{
5        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6        EmbedBuilder, EmbedContentRequest,
7    },
8    files::{
9        handle::FileHandle,
10        model::{File, ListFilesResponse},
11    },
12    generation::{ContentBuilder, GenerateContentRequest, GenerationResponse},
13};
14use eventsource_stream::{EventStreamError, Eventsource};
15use futures::{Stream, StreamExt, TryStreamExt};
16use mime::Mime;
17use reqwest::{
18    header::{HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue},
19    Client, ClientBuilder, RequestBuilder, Response,
20};
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23use snafu::{OptionExt, ResultExt, Snafu};
24use std::{
25    fmt::{self, Formatter},
26    pin::Pin,
27    sync::{Arc, LazyLock},
28};
29use tracing::{instrument, Level, Span};
30use url::Url;
31
32use crate::batch::model::*;
33use crate::cache::model::*;
34
35/// Type alias for streaming generation responses
36///
37/// A pinned, boxed stream that yields `GenerationResponse` chunks as they arrive
38/// from the API. Used for streaming content generation to receive partial results
39/// before the complete response is ready.
40pub type GenerationStream = Pin<Box<dyn Stream<Item = Result<GenerationResponse, Error>> + Send>>;
41
42static DEFAULT_BASE_URL: LazyLock<Url> = LazyLock::new(|| {
43    Url::parse("https://generativelanguage.googleapis.com/v1beta/")
44        .expect("unreachable error: failed to parse default base URL")
45});
46
47#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
48pub enum Model {
49    #[default]
50    #[serde(rename = "models/gemini-2.5-flash")]
51    Gemini25Flash,
52    #[serde(rename = "models/gemini-2.5-flash-lite")]
53    Gemini25FlashLite,
54    #[serde(rename = "models/gemini-2.5-flash-image")]
55    Gemini25FlashImage,
56    #[serde(rename = "models/gemini-2.5-pro")]
57    Gemini25Pro,
58    #[serde(rename = "models/gemini-3-flash-preview")]
59    Gemini3Flash,
60    #[serde(rename = "models/gemini-3-pro-preview")]
61    Gemini3Pro,
62    #[serde(rename = "models/gemini-3-pro-image-preview")]
63    Gemini3ProImage,
64    #[serde(rename = "models/text-embedding-004")]
65    TextEmbedding004,
66    #[serde(untagged)]
67    Custom(String),
68}
69
70impl Model {
71    pub fn as_str(&self) -> &str {
72        match self {
73            Model::Gemini25Flash => "models/gemini-2.5-flash",
74            Model::Gemini25FlashLite => "models/gemini-2.5-flash-lite",
75            Model::Gemini25FlashImage => "models/gemini-2.5-flash-image",
76            Model::Gemini25Pro => "models/gemini-2.5-pro",
77            Model::Gemini3Flash => "models/gemini-3-flash-preview",
78            Model::Gemini3Pro => "models/gemini-3-pro-preview",
79            Model::Gemini3ProImage => "models/gemini-3-pro-image-preview",
80            Model::TextEmbedding004 => "models/text-embedding-004",
81            Model::Custom(model) => model,
82        }
83    }
84}
85
86impl From<String> for Model {
87    fn from(model: String) -> Self {
88        Self::Custom(model)
89    }
90}
91
92impl fmt::Display for Model {
93    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
94        match self {
95            Model::Gemini25Flash => write!(f, "models/gemini-2.5-flash"),
96            Model::Gemini25FlashLite => write!(f, "models/gemini-2.5-flash-lite"),
97            Model::Gemini25FlashImage => write!(f, "models/gemini-2.5-flash-image"),
98            Model::Gemini25Pro => write!(f, "models/gemini-2.5-pro"),
99            Model::Gemini3Flash => write!(f, "models/gemini-3-flash-preview"),
100            Model::Gemini3Pro => write!(f, "models/gemini-3-pro-preview"),
101            Model::Gemini3ProImage => write!(f, "models/gemini-3-pro-image-preview"),
102            Model::TextEmbedding004 => write!(f, "models/text-embedding-004"),
103            Model::Custom(model) => write!(f, "{model}"),
104        }
105    }
106}
107
108#[derive(Debug, Snafu)]
109#[snafu(visibility(pub))]
110pub enum Error {
111    #[snafu(display("failed to parse API key"))]
112    InvalidApiKey {
113        source: InvalidHeaderValue,
114    },
115
116    #[snafu(display("failed to construct URL (probably incorrect model name): {suffix}"))]
117    ConstructUrl {
118        source: url::ParseError,
119        suffix: String,
120    },
121
122    PerformRequestNew {
123        source: reqwest::Error,
124    },
125
126    #[snafu(display("failed to perform request to '{url}'"))]
127    PerformRequest {
128        source: reqwest::Error,
129        url: Url,
130    },
131
132    #[snafu(display(
133        "bad response from server; code {code}; description: {}",
134        description.as_deref().unwrap_or("none")
135    ))]
136    BadResponse {
137        /// HTTP status code
138        code: u16,
139        /// HTTP error description
140        description: Option<String>,
141    },
142
143    MissingResponseHeader {
144        header: String,
145    },
146
147    #[snafu(display("failed to obtain stream SSE part"))]
148    BadPart {
149        source: EventStreamError<reqwest::Error>,
150    },
151
152    #[snafu(display("failed to deserialize JSON response"))]
153    Deserialize {
154        source: serde_json::Error,
155    },
156
157    #[snafu(display("failed to generate content"))]
158    DecodeResponse {
159        source: reqwest::Error,
160    },
161
162    #[snafu(display("failed to parse URL"))]
163    UrlParse {
164        source: url::ParseError,
165    },
166
167    #[snafu(display("I/O error during file operations"))]
168    Io {
169        source: std::io::Error,
170    },
171}
172
173/// Internal client for making requests to the Gemini API
174pub struct GeminiClient {
175    http_client: Client,
176    pub model: Model,
177    base_url: Url,
178}
179
180impl GeminiClient {
181    /// Create a new client with custom base URL
182    fn with_base_url<K: AsRef<str>, M: Into<Model>>(
183        client_builder: ClientBuilder,
184        api_key: K,
185        model: M,
186        base_url: Url,
187    ) -> Result<Self, Error> {
188        let headers = HeaderMap::from_iter([(
189            HeaderName::from_static("x-goog-api-key"),
190            HeaderValue::from_str(api_key.as_ref()).context(InvalidApiKeySnafu)?,
191        )]);
192
193        let http_client = client_builder
194            .default_headers(headers)
195            .build()
196            .expect("all parameters must be valid");
197
198        Ok(Self {
199            http_client,
200            model: model.into(),
201            base_url,
202        })
203    }
204
205    /// Check the response status code and return an error if it is not successful
206    #[tracing::instrument(skip_all, err)]
207    async fn check_response(response: Response) -> Result<Response, Error> {
208        let status = response.status();
209        if !status.is_success() {
210            let description = response.text().await.ok();
211            BadResponseSnafu {
212                code: status.as_u16(),
213                description,
214            }
215            .fail()
216        } else {
217            Ok(response)
218        }
219    }
220
221    /// Performs an HTTP request to the Gemini API with standardized error handling.
222    ///
223    /// This method provides a generic way to make HTTP requests to the Gemini API with
224    /// consistent error handling, response checking, and deserialization. It handles:
225    /// - Building the HTTP request using a provided builder function
226    /// - Sending the request and handling network errors
227    /// - Checking the response status code for errors
228    /// - Deserializing the response using a provided deserializer function
229    ///
230    /// # Type Parameters
231    /// * `B` - A function that takes a `&Client` and returns a `RequestBuilder`
232    /// * `D` - An async function that takes ownership of a `Response` and returns a `Result<T, Error>`
233    /// * `T` - The type of the deserialized response
234    ///
235    /// # Note
236    /// The `AsyncFn` trait is a standard Rust feature (stabilized in v1.85) and does not
237    /// require any additional imports or feature flags.
238    ///
239    /// # Parameters
240    /// * `builder` - A function that constructs the HTTP request using the client
241    /// * `deserializer` - An async function that processes the response into the desired type
242    ///
243    /// # Examples
244    ///
245    /// Basic HTTP operations:
246    /// ```no_run
247    /// # use gemini_rust::client::*;
248    /// # use reqwest::Response;
249    /// # use url::Url;
250    /// # use serde_json::Value;
251    /// # use snafu::ResultExt;
252    /// # async fn examples(client: &GeminiClient) -> Result<(), Box<dyn std::error::Error>> {
253    /// # let url: Url = "https://example.com".parse()?;
254    /// # let request = Value::Null;
255    ///
256    /// // POST request with JSON payload
257    /// let _response : () = client
258    ///     .perform_request(
259    ///         |c| c.post(url.clone()).json(&request),
260    ///         async |r| r.json().await.context(DecodeResponseSnafu),
261    ///     )
262    ///     .await?;
263    ///
264    /// // GET request with JSON response
265    /// let _response : () = client
266    ///     .perform_request(
267    ///         |c| c.get(url.clone()),
268    ///         async |r| r.json().await.context(DecodeResponseSnafu),
269    ///     )
270    ///     .await?;
271    ///
272    /// // DELETE request with no response body
273    /// let _response = client
274    ///     .perform_request(|c| c.delete(url), async |_r| Ok(()))
275    ///     .await?;
276    /// # Ok(())
277    /// # }
278    /// ```
279    ///
280    /// Request returning a stream:
281    /// ```no_run
282    /// # use gemini_rust::client::*;
283    /// # use reqwest::Response;
284    /// # use url::Url;
285    /// # use serde_json::Value;
286    /// # async fn example(client: &GeminiClient) -> Result<(), Box<dyn std::error::Error>> {
287    /// # let url: Url = "https://example.com".parse()?;
288    /// # let request = Value::Null;
289    /// let _stream = client
290    ///     .perform_request(
291    ///         |c| c.post(url).json(&request),
292    ///         async |r| Ok(r.bytes_stream()),
293    ///     )
294    ///     .await?;
295    /// # Ok(())
296    /// # }
297    /// ```
298    #[tracing::instrument(skip_all)]
299    #[doc(hidden)]
300    pub async fn perform_request<
301        B: FnOnce(&Client) -> RequestBuilder,
302        D: AsyncFn(Response) -> Result<T, Error>,
303        T,
304    >(
305        &self,
306        builder: B,
307        deserializer: D,
308    ) -> Result<T, Error> {
309        let request = builder(&self.http_client);
310        tracing::debug!("request built successfully");
311        let response = request.send().await.context(PerformRequestNewSnafu)?;
312        tracing::debug!("response received successfully");
313        let response = Self::check_response(response).await?;
314        tracing::debug!("response ok");
315        deserializer(response).await
316    }
317
318    /// Perform a GET request and deserialize the JSON response.
319    ///
320    /// This is a convenience wrapper around [`perform_request`](Self::perform_request).
321    #[tracing::instrument(skip(self), fields(request.type = "get", request.url = %url))]
322    async fn get_json<T: serde::de::DeserializeOwned>(&self, url: Url) -> Result<T, Error> {
323        self.perform_request(
324            |c| c.get(url),
325            async |r| r.json().await.context(DecodeResponseSnafu),
326        )
327        .await
328    }
329
330    /// Perform a POST request with JSON body and deserialize the JSON response.
331    ///
332    /// This is a convenience wrapper around [`perform_request`](Self::perform_request).
333    #[tracing::instrument(skip(self, body), fields(request.type = "post", request.url = %url))]
334    async fn post_json<Req: serde::Serialize, Res: serde::de::DeserializeOwned>(
335        &self,
336        url: Url,
337        body: &Req,
338    ) -> Result<Res, Error> {
339        self.perform_request(
340            |c| c.post(url).json(body),
341            async |r| r.json().await.context(DecodeResponseSnafu),
342        )
343        .await
344    }
345
346    /// Generate content
347    #[instrument(skip_all, fields(
348        model,
349        messages.parts.count = request.contents.len(),
350        tools.present = request.tools.is_some(),
351        system.instruction.present = request.system_instruction.is_some(),
352        cached.content.present = request.cached_content.is_some(),
353        usage.prompt_tokens,
354        usage.candidates_tokens,
355        usage.thoughts_tokens,
356        usage.cached_content_tokens,
357        usage.total_tokens,
358    ), ret(level = Level::TRACE), err)]
359    pub(crate) async fn generate_content_raw(
360        &self,
361        request: GenerateContentRequest,
362    ) -> Result<GenerationResponse, Error> {
363        let url = self.build_url("generateContent")?;
364        let response: GenerationResponse = self.post_json(url, &request).await?;
365
366        // Record usage metadata
367        if let Some(usage) = &response.usage_metadata {
368            #[rustfmt::skip]
369            Span::current()
370                .record("usage.prompt_tokens", usage.prompt_token_count)
371                .record("usage.candidates_tokens", usage.candidates_token_count)
372                .record("usage.thoughts_tokens", usage.thoughts_token_count)
373                .record("usage.cached_content_tokens", usage.cached_content_token_count)
374                .record("usage.total_tokens", usage.total_token_count);
375
376            tracing::debug!("generation usage evaluated");
377        }
378
379        Ok(response)
380    }
381
382    /// Generate content with streaming
383    #[instrument(skip_all, fields(
384        model,
385        messages.parts.count = request.contents.len(),
386        tools.present = request.tools.is_some(),
387        system.instruction.present = request.system_instruction.is_some(),
388        cached.content.present = request.cached_content.is_some(),
389    ), err)]
390    pub(crate) async fn generate_content_stream(
391        &self,
392        request: GenerateContentRequest,
393    ) -> Result<GenerationStream, Error> {
394        let mut url = self.build_url("streamGenerateContent")?;
395        url.query_pairs_mut().append_pair("alt", "sse");
396
397        let stream = self
398            .perform_request(
399                |c| c.post(url).json(&request),
400                async |r| Ok(r.bytes_stream()),
401            )
402            .await?;
403
404        Ok(Box::pin(
405            stream
406                .eventsource()
407                .map(|event| event.context(BadPartSnafu))
408                .and_then(|event| async move {
409                    serde_json::from_str::<GenerationResponse>(&event.data)
410                        .context(DeserializeSnafu)
411                }),
412        ))
413    }
414
415    /// Count tokens for content
416    #[instrument(skip_all, fields(
417        model,
418        messages.parts.count = request.contents.len(),
419    ))]
420    pub(crate) async fn count_tokens(
421        &self,
422        request: GenerateContentRequest,
423    ) -> Result<crate::generation::CountTokensResponse, Error> {
424        let url = self.build_url("countTokens")?;
425        // Wrap the request in a "generateContentRequest" field and explicitly add the model.
426        // The countTokens API requires the model to be specified within generateContentRequest.
427        let body = json!({
428            "generateContentRequest": {
429                "model": self.model.as_str(),
430                "contents": request.contents,
431                "generationConfig": request.generation_config,
432                "safetySettings": request.safety_settings,
433                "tools": request.tools,
434                "toolConfig": request.tool_config,
435                "systemInstruction": request.system_instruction,
436                "cachedContent": request.cached_content,
437            }
438        });
439        self.post_json(url, &body).await
440    }
441
442    /// Embed content
443    #[instrument(skip_all, fields(
444        model,
445        task.type = request.task_type.as_ref().map(|t| format!("{t:?}")),
446        task.title = request.title,
447        task.output.dimensionality = request.output_dimensionality,
448    ))]
449    pub(crate) async fn embed_content(
450        &self,
451        request: EmbedContentRequest,
452    ) -> Result<ContentEmbeddingResponse, Error> {
453        let url = self.build_url("embedContent")?;
454        self.post_json(url, &request).await
455    }
456
457    /// Batch Embed content
458    #[instrument(skip_all, fields(batch.size = request.requests.len()))]
459    pub(crate) async fn embed_content_batch(
460        &self,
461        request: BatchEmbedContentsRequest,
462    ) -> Result<BatchContentEmbeddingResponse, Error> {
463        let url = self.build_url("batchEmbedContents")?;
464        self.post_json(url, &request).await
465    }
466
467    /// Batch generate content (synchronous API that returns results immediately)
468    #[instrument(skip_all, fields(
469        batch.display_name = request.batch.display_name,
470        batch.size = request.batch.input_config.batch_size(),
471    ))]
472    pub(crate) async fn batch_generate_content(
473        &self,
474        request: BatchGenerateContentRequest,
475    ) -> Result<BatchGenerateContentResponse, Error> {
476        let url = self.build_url("batchGenerateContent")?;
477        self.post_json(url, &request).await
478    }
479
480    /// Get a batch operation
481    #[instrument(skip_all, fields(
482        operation.name = name,
483    ))]
484    pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
485        &self,
486        name: &str,
487    ) -> Result<T, Error> {
488        let url = self.build_batch_url(name, None)?;
489        self.get_json(url).await
490    }
491
492    /// List batch operations
493    #[instrument(skip_all, fields(
494        page.size = page_size,
495        page.token.present = page_token.is_some(),
496    ))]
497    pub(crate) async fn list_batch_operations(
498        &self,
499        page_size: Option<u32>,
500        page_token: Option<String>,
501    ) -> Result<ListBatchesResponse, Error> {
502        let mut url = self.build_batch_url("batches", None)?;
503
504        if let Some(size) = page_size {
505            url.query_pairs_mut()
506                .append_pair("pageSize", &size.to_string());
507        }
508        if let Some(token) = page_token {
509            url.query_pairs_mut().append_pair("pageToken", &token);
510        }
511
512        self.get_json(url).await
513    }
514
515    /// List files
516    #[instrument(skip_all, fields(
517        page.size = page_size,
518        page.token.present = page_token.is_some(),
519    ))]
520    pub(crate) async fn list_files(
521        &self,
522        page_size: Option<u32>,
523        page_token: Option<String>,
524    ) -> Result<ListFilesResponse, Error> {
525        let mut url = self.build_files_url(None)?;
526
527        if let Some(size) = page_size {
528            url.query_pairs_mut()
529                .append_pair("pageSize", &size.to_string());
530        }
531        if let Some(token) = page_token {
532            url.query_pairs_mut().append_pair("pageToken", &token);
533        }
534
535        self.get_json(url).await
536    }
537
538    /// Cancel a batch operation
539    #[instrument(skip_all, fields(
540        operation.name = name,
541    ))]
542    pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<(), Error> {
543        let url = self.build_batch_url(name, Some("cancel"))?;
544        self.perform_request(|c| c.post(url).json(&json!({})), async |_r| Ok(()))
545            .await
546    }
547
548    /// Delete a batch operation
549    #[instrument(skip_all, fields(
550        operation.name = name,
551    ))]
552    pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<(), Error> {
553        let url = self.build_batch_url(name, None)?;
554        self.perform_request(|c| c.delete(url), async |_r| Ok(()))
555            .await
556    }
557
558    async fn create_upload(
559        &self,
560        bytes: usize,
561        display_name: Option<String>,
562        mime_type: Mime,
563    ) -> Result<Url, Error> {
564        let url = self
565            .base_url
566            .join("/upload/v1beta/files")
567            .context(ConstructUrlSnafu {
568                suffix: "/upload/v1beta/files".to_string(),
569            })?;
570
571        self.perform_request(
572            |c| {
573                c.post(url)
574                    .header("X-Goog-Upload-Protocol", "resumable")
575                    .header("X-Goog-Upload-Command", "start")
576                    .header("X-Goog-Upload-Content-Length", bytes.to_string())
577                    .header("X-Goog-Upload-Header-Content-Type", mime_type.to_string())
578                    .json(&json!({"file": {"displayName": display_name}}))
579            },
580            async |r| {
581                r.headers()
582                    .get("X-Goog-Upload-URL")
583                    .context(MissingResponseHeaderSnafu {
584                        header: "X-Goog-Upload-URL",
585                    })
586                    .and_then(|upload_url| {
587                        upload_url
588                            .to_str()
589                            .map(str::to_string)
590                            .map_err(|_| Error::BadResponse {
591                                code: 500,
592                                description: Some("Missing upload URL in response".to_string()),
593                            })
594                    })
595                    .and_then(|url| Url::parse(&url).context(UrlParseSnafu))
596            },
597        )
598        .await
599    }
600
601    /// Upload a file using the resumable upload protocol.
602    #[instrument(skip_all, fields(
603        file.size = file_bytes.len(),
604        mime.type = mime_type.to_string(),
605        file.display_name = display_name.as_deref(),
606    ))]
607    pub(crate) async fn upload_file(
608        &self,
609        display_name: Option<String>,
610        file_bytes: Vec<u8>,
611        mime_type: Mime,
612    ) -> Result<File, Error> {
613        // Step 1: Create resumable upload session
614        let upload_url = self
615            .create_upload(file_bytes.len(), display_name, mime_type)
616            .await?;
617
618        // Step 2: Upload file content
619        let upload_response = self
620            .http_client
621            .post(upload_url.clone())
622            .header("X-Goog-Upload-Command", "upload, finalize")
623            .header("X-Goog-Upload-Offset", "0")
624            .body(file_bytes)
625            .send()
626            .await
627            .map_err(|e| Error::PerformRequest {
628                source: e,
629                url: upload_url,
630            })?;
631
632        let final_response = Self::check_response(upload_response).await?;
633
634        #[derive(serde::Deserialize)]
635        struct UploadResponse {
636            file: File,
637        }
638
639        let upload_response: UploadResponse =
640            final_response.json().await.context(DecodeResponseSnafu)?;
641        Ok(upload_response.file)
642    }
643
644    /// Get a file resource
645    #[instrument(skip_all, fields(
646        file.name = name,
647    ))]
648    pub(crate) async fn get_file(&self, name: &str) -> Result<File, Error> {
649        let url = self.build_files_url(Some(name))?;
650        self.get_json(url).await
651    }
652
653    /// Delete a file resource
654    #[instrument(skip_all, fields(
655        file.name = name,
656    ))]
657    pub(crate) async fn delete_file(&self, name: &str) -> Result<(), Error> {
658        let url = self.build_files_url(Some(name))?;
659        self.perform_request(|c| c.delete(url), async |_r| Ok(()))
660            .await
661    }
662
663    /// Download a file resource
664    #[instrument(skip_all, fields(
665        file.name = name,
666    ))]
667    pub(crate) async fn download_file(&self, name: &str) -> Result<Vec<u8>, Error> {
668        let mut url = self
669            .base_url
670            .join(&format!("/download/v1beta/{name}:download"))
671            .context(ConstructUrlSnafu {
672                suffix: format!("/download/v1beta/{name}:download"),
673            })?;
674        url.query_pairs_mut().append_pair("alt", "media");
675
676        self.perform_request(
677            |c| c.get(url),
678            async |r| {
679                r.bytes()
680                    .await
681                    .context(DecodeResponseSnafu)
682                    .map(|bytes| bytes.to_vec())
683            },
684        )
685        .await
686    }
687
688    /// Create cached content
689    pub(crate) async fn create_cached_content(
690        &self,
691        cached_content: CreateCachedContentRequest,
692    ) -> Result<CachedContent, Error> {
693        let url = self.build_cache_url(None)?;
694        self.post_json(url, &cached_content).await
695    }
696
697    /// Get cached content
698    pub(crate) async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
699        let url = self.build_cache_url(Some(name))?;
700        self.get_json(url).await
701    }
702
703    /// Update cached content (typically to update TTL)
704    pub(crate) async fn update_cached_content(
705        &self,
706        name: &str,
707        expiration: CacheExpirationRequest,
708    ) -> Result<CachedContent, Error> {
709        let url = self.build_cache_url(Some(name))?;
710
711        // Create a minimal update payload with just the expiration
712        let update_payload = match expiration {
713            CacheExpirationRequest::Ttl { ttl } => json!({ "ttl": ttl }),
714            CacheExpirationRequest::ExpireTime { expire_time } => {
715                json!({ "expireTime": expire_time.format(&time::format_description::well_known::Rfc3339).unwrap() })
716            }
717        };
718
719        self.perform_request(
720            |c| c.patch(url.clone()).json(&update_payload),
721            async |r| r.json().await.context(DecodeResponseSnafu),
722        )
723        .await
724    }
725
726    /// Delete cached content
727    pub(crate) async fn delete_cached_content(&self, name: &str) -> Result<(), Error> {
728        let url = self.build_cache_url(Some(name))?;
729        self.perform_request(|c| c.delete(url.clone()), async |_r| Ok(()))
730            .await
731    }
732
733    /// List cached contents
734    pub(crate) async fn list_cached_contents(
735        &self,
736        page_size: Option<i32>,
737        page_token: Option<String>,
738    ) -> Result<ListCachedContentsResponse, Error> {
739        let mut url = self.build_cache_url(None)?;
740
741        if let Some(size) = page_size {
742            url.query_pairs_mut()
743                .append_pair("pageSize", &size.to_string());
744        }
745        if let Some(token) = page_token {
746            url.query_pairs_mut().append_pair("pageToken", &token);
747        }
748
749        self.get_json(url).await
750    }
751
752    /// Build a URL with the given suffix
753    #[tracing::instrument(skip(self), ret(level = Level::DEBUG))]
754    fn build_url_with_suffix(&self, suffix: &str) -> Result<Url, Error> {
755        self.base_url.join(suffix).context(ConstructUrlSnafu {
756            suffix: suffix.to_string(),
757        })
758    }
759
760    /// Build a URL for the API
761    #[tracing::instrument(skip(self), ret(level = Level::DEBUG))]
762    fn build_url(&self, endpoint: &str) -> Result<Url, Error> {
763        let suffix = format!("{}:{endpoint}", self.model);
764        self.build_url_with_suffix(&suffix)
765    }
766
767    /// Build a URL for a batch operation
768    fn build_batch_url(&self, name: &str, action: Option<&str>) -> Result<Url, Error> {
769        let suffix = action
770            .map(|a| format!("{name}:{a}"))
771            .unwrap_or_else(|| name.to_string());
772        self.build_url_with_suffix(&suffix)
773    }
774
775    /// Build a URL for file operations
776    fn build_files_url(&self, name: Option<&str>) -> Result<Url, Error> {
777        let suffix = name
778            .map(|n| format!("files/{}", n.strip_prefix("files/").unwrap_or(n)))
779            .unwrap_or_else(|| "files".to_string());
780        self.build_url_with_suffix(&suffix)
781    }
782
783    /// Build a URL for cache operations
784    fn build_cache_url(&self, name: Option<&str>) -> Result<Url, Error> {
785        let suffix = name
786            .map(|n| {
787                if n.starts_with("cachedContents/") {
788                    n.to_string()
789                } else {
790                    format!("cachedContents/{n}")
791                }
792            })
793            .unwrap_or_else(|| "cachedContents".to_string());
794        self.build_url_with_suffix(&suffix)
795    }
796}
797
798/// A builder for the `Gemini` client.
799///
800/// # Examples
801///
802/// ## Basic usage
803///
804/// ```no_run
805/// use gemini_rust::{GeminiBuilder, Model};
806///
807/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
808/// let gemini = GeminiBuilder::new("YOUR_API_KEY")
809///     .with_model(Model::Gemini25Pro)
810///     .build()?;
811/// # Ok(())
812/// # }
813/// ```
814///
815/// ## With proxy configuration
816///
817/// ```no_run
818/// use gemini_rust::{GeminiBuilder, Model};
819/// use reqwest::{ClientBuilder, Proxy};
820///
821/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
822/// let proxy = Proxy::https("https://my.proxy")?;
823/// let http_client = ClientBuilder::new().proxy(proxy);
824///
825/// let gemini = GeminiBuilder::new("YOUR_API_KEY")
826///     .with_http_client(http_client)
827///     .build()?;
828/// # Ok(())
829/// # }
830/// ```
831pub struct GeminiBuilder {
832    key: String,
833    model: Model,
834    client_builder: ClientBuilder,
835    base_url: Url,
836}
837
838impl GeminiBuilder {
839    /// Creates a new `GeminiBuilder` with the given API key.
840    pub fn new<K: Into<String>>(key: K) -> Self {
841        Self {
842            key: key.into(),
843            model: Model::default(),
844            client_builder: ClientBuilder::default(),
845            base_url: DEFAULT_BASE_URL.clone(),
846        }
847    }
848
849    /// Sets the model for the client.
850    pub fn with_model<M: Into<Model>>(mut self, model: M) -> Self {
851        self.model = model.into();
852        self
853    }
854
855    /// Sets a custom `reqwest::ClientBuilder`.
856    pub fn with_http_client(mut self, client_builder: ClientBuilder) -> Self {
857        self.client_builder = client_builder;
858        self
859    }
860
861    /// Sets a custom base URL for the API.
862    pub fn with_base_url(mut self, base_url: Url) -> Self {
863        self.base_url = base_url;
864        self
865    }
866
867    /// Builds the `Gemini` client.
868    pub fn build(self) -> Result<Gemini, Error> {
869        Ok(Gemini {
870            client: Arc::new(GeminiClient::with_base_url(
871                self.client_builder,
872                self.key,
873                self.model,
874                self.base_url,
875            )?),
876        })
877    }
878}
879
/// Client for the Gemini API
///
/// Cloning is cheap: every clone shares the same underlying `GeminiClient`
/// through an `Arc`, so one instance can be handed to many tasks.
#[derive(Clone)]
pub struct Gemini {
    /// Shared inner client; a clone of this `Arc` is passed to every request
    /// builder and handle created from this `Gemini`.
    client: Arc<GeminiClient>,
}
885
886impl Gemini {
887    /// Create a new client with the specified API key
888    pub fn new<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
889        Self::with_model(api_key, Model::default())
890    }
891
892    /// Create a new client for the Gemini Pro model
893    pub fn pro<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
894        Self::with_model(api_key, Model::Gemini25Pro)
895    }
896
897    /// Create a new client for the Gemini Pro 3 image model
898    pub fn pro_image<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
899        Self::with_model(api_key, Model::Gemini3ProImage)
900    }
901
902    /// Create a new client with the specified API key and model
903    pub fn with_model<K: AsRef<str>, M: Into<Model>>(api_key: K, model: M) -> Result<Self, Error> {
904        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.clone())
905    }
906
907    /// Create a new client with custom base URL
908    pub fn with_base_url<K: AsRef<str>>(api_key: K, base_url: Url) -> Result<Self, Error> {
909        Self::with_model_and_base_url(api_key, Model::default(), base_url)
910    }
911
912    /// Create a new client with the specified API key, model, and base URL
913    pub fn with_model_and_base_url<K: AsRef<str>, M: Into<Model>>(
914        api_key: K,
915        model: M,
916        base_url: Url,
917    ) -> Result<Self, Error> {
918        let client =
919            GeminiClient::with_base_url(Default::default(), api_key, model.into(), base_url)?;
920        Ok(Self {
921            client: Arc::new(client),
922        })
923    }
924
925    /// Start building a content generation request
926    pub fn generate_content(&self) -> ContentBuilder {
927        ContentBuilder::new(self.client.clone())
928    }
929
930    /// Start building a content embedding request
931    pub fn embed_content(&self) -> EmbedBuilder {
932        EmbedBuilder::new(self.client.clone())
933    }
934
935    /// Start building a batch content generation request
936    pub fn batch_generate_content(&self) -> BatchBuilder {
937        BatchBuilder::new(self.client.clone())
938    }
939
940    /// Get a handle to a batch operation by its name.
941    pub fn get_batch(&self, name: &str) -> BatchHandle {
942        BatchHandle::new(name.to_string(), self.client.clone())
943    }
944
945    /// Lists batch operations.
946    ///
947    /// This method returns a stream that handles pagination automatically.
948    pub fn list_batches(
949        &self,
950        page_size: impl Into<Option<u32>>,
951    ) -> impl Stream<Item = Result<BatchOperation, Error>> + Send {
952        let client = self.client.clone();
953        let page_size = page_size.into();
954        async_stream::try_stream! {
955            let mut page_token: Option<String> = None;
956            loop {
957                let response = client
958                    .list_batch_operations(page_size, page_token.clone())
959                    .await?;
960
961                for operation in response.operations {
962                    yield operation;
963                }
964
965                if let Some(next_page_token) = response.next_page_token {
966                    page_token = Some(next_page_token);
967                } else {
968                    break;
969                }
970            }
971        }
972    }
973
974    /// Create cached content with a fluent API.
975    pub fn create_cache(&self) -> CacheBuilder {
976        CacheBuilder::new(self.client.clone())
977    }
978
979    /// Get a handle to cached content by its name.
980    pub fn get_cached_content(&self, name: &str) -> CachedContentHandle {
981        CachedContentHandle::new(name.to_string(), self.client.clone())
982    }
983
984    /// Lists cached contents.
985    ///
986    /// This method returns a stream that handles pagination automatically.
987    pub fn list_cached_contents(
988        &self,
989        page_size: impl Into<Option<i32>>,
990    ) -> impl Stream<Item = Result<CachedContentSummary, Error>> + Send {
991        let client = self.client.clone();
992        let page_size = page_size.into();
993        async_stream::try_stream! {
994            let mut page_token: Option<String> = None;
995            loop {
996                let response = client
997                    .list_cached_contents(page_size, page_token.clone())
998                    .await?;
999
1000                for cached_content in response.cached_contents {
1001                    yield cached_content;
1002                }
1003
1004                if let Some(next_page_token) = response.next_page_token {
1005                    page_token = Some(next_page_token);
1006                } else {
1007                    break;
1008                }
1009            }
1010        }
1011    }
1012
1013    /// Start building a file resource
1014    pub fn create_file<B: Into<Vec<u8>>>(&self, bytes: B) -> crate::files::builder::FileBuilder {
1015        crate::files::builder::FileBuilder::new(self.client.clone(), bytes)
1016    }
1017
1018    /// Get a handle to a file by its name.
1019    pub async fn get_file(&self, name: &str) -> Result<FileHandle, Error> {
1020        let file = self.client.get_file(name).await?;
1021        Ok(FileHandle::new(self.client.clone(), file))
1022    }
1023
1024    /// Lists files.
1025    ///
1026    /// This method returns a stream that handles pagination automatically.
1027    pub fn list_files(
1028        &self,
1029        page_size: impl Into<Option<u32>>,
1030    ) -> impl Stream<Item = Result<FileHandle, Error>> + Send {
1031        let client = self.client.clone();
1032        let page_size = page_size.into();
1033        async_stream::try_stream! {
1034            let mut page_token: Option<String> = None;
1035            loop {
1036                let response = client
1037                    .list_files(page_size, page_token.clone())
1038                    .await?;
1039
1040                for file in response.files {
1041                    yield FileHandle::new(client.clone(), file);
1042                }
1043
1044                if let Some(next_page_token) = response.next_page_token {
1045                    page_token = Some(next_page_token);
1046                } else {
1047                    break;
1048                }
1049            }
1050        }
1051    }
1052}