gemini_rust/
client.rs

1use crate::{
2    batch::{BatchBuilder, BatchHandle},
3    cache::{CacheBuilder, CachedContentHandle},
4    embedding::{
5        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, ContentEmbeddingResponse,
6        EmbedBuilder, EmbedContentRequest,
7    },
8    files::{
9        handle::FileHandle,
10        model::{File, ListFilesResponse},
11    },
12    generation::{ContentBuilder, GenerateContentRequest, GenerationResponse},
13};
14use eventsource_stream::{EventStreamError, Eventsource};
15use futures::{Stream, StreamExt, TryStreamExt};
16use mime::Mime;
17use reqwest::{
18    header::{HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue},
19    Client, ClientBuilder, RequestBuilder, Response,
20};
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23use snafu::{OptionExt, ResultExt, Snafu};
24use std::{
25    fmt::{self, Formatter},
26    sync::{Arc, LazyLock},
27};
28use tracing::{instrument, Level, Span};
29use url::Url;
30
31use crate::batch::model::*;
32use crate::cache::model::*;
33
34static DEFAULT_BASE_URL: LazyLock<Url> = LazyLock::new(|| {
35    Url::parse("https://generativelanguage.googleapis.com/v1beta/")
36        .expect("unreachable error: failed to parse default base URL")
37});
38
39#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
40pub enum Model {
41    #[default]
42    #[serde(rename = "models/gemini-2.5-flash")]
43    Gemini25Flash,
44    #[serde(rename = "models/gemini-2.5-flash-lite")]
45    Gemini25FlashLite,
46    #[serde(rename = "models/gemini-2.5-pro")]
47    Gemini25Pro,
48    #[serde(rename = "models/text-embedding-004")]
49    TextEmbedding004,
50    #[serde(untagged)]
51    Custom(String),
52}
53
54impl Model {
55    pub fn as_str(&self) -> &str {
56        match self {
57            Model::Gemini25Flash => "models/gemini-2.5-flash",
58            Model::Gemini25FlashLite => "models/gemini-2.5-flash-lite",
59            Model::Gemini25Pro => "models/gemini-2.5-pro",
60            Model::TextEmbedding004 => "models/text-embedding-004",
61            Model::Custom(model) => model,
62        }
63    }
64}
65
66impl From<String> for Model {
67    fn from(model: String) -> Self {
68        Self::Custom(model)
69    }
70}
71
72impl fmt::Display for Model {
73    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
74        match self {
75            Model::Gemini25Flash => write!(f, "models/gemini-2.5-flash"),
76            Model::Gemini25FlashLite => write!(f, "models/gemini-2.5-flash-lite"),
77            Model::Gemini25Pro => write!(f, "models/gemini-2.5-pro"),
78            Model::TextEmbedding004 => write!(f, "models/text-embedding-004"),
79            Model::Custom(model) => write!(f, "{}", model),
80        }
81    }
82}
83
84#[derive(Debug, Snafu)]
85#[snafu(visibility(pub))]
86pub enum Error {
87    #[snafu(display("failed to parse API key"))]
88    InvalidApiKey {
89        source: InvalidHeaderValue,
90    },
91
92    #[snafu(display("failed to construct URL (probably incorrect model name): {suffix}"))]
93    ConstructUrl {
94        source: url::ParseError,
95        suffix: String,
96    },
97
98    PerformRequestNew {
99        source: reqwest::Error,
100    },
101
102    #[snafu(display("failed to perform request to '{url}'"))]
103    PerformRequest {
104        source: reqwest::Error,
105        url: Url,
106    },
107
108    #[snafu(display(
109        "bad response from server; code {code}; description: {}",
110        description.as_deref().unwrap_or("none")
111    ))]
112    BadResponse {
113        /// HTTP status code
114        code: u16,
115        /// HTTP error description
116        description: Option<String>,
117    },
118
119    MissingResponseHeader {
120        header: String,
121    },
122
123    #[snafu(display("failed to obtain stream SSE part"))]
124    BadPart {
125        source: EventStreamError<reqwest::Error>,
126    },
127
128    #[snafu(display("failed to deserialize JSON response"))]
129    Deserialize {
130        source: serde_json::Error,
131    },
132
133    #[snafu(display("failed to generate content"))]
134    DecodeResponse {
135        source: reqwest::Error,
136    },
137
138    #[snafu(display("failed to parse URL"))]
139    UrlParse {
140        source: url::ParseError,
141    },
142
143    #[snafu(display("I/O error during file operations"))]
144    Io {
145        source: std::io::Error,
146    },
147}
148
149/// Internal client for making requests to the Gemini API
150pub struct GeminiClient {
151    http_client: Client,
152    pub model: Model,
153    base_url: Url,
154}
155
156impl GeminiClient {
157    /// Create a new client with custom base URL
158    fn with_base_url<K: AsRef<str>, M: Into<Model>>(
159        client_builder: ClientBuilder,
160        api_key: K,
161        model: M,
162        base_url: Url,
163    ) -> Result<Self, Error> {
164        let headers = HeaderMap::from_iter([(
165            HeaderName::from_static("x-goog-api-key"),
166            HeaderValue::from_str(api_key.as_ref()).context(InvalidApiKeySnafu)?,
167        )]);
168
169        let http_client = client_builder
170            .default_headers(headers)
171            .build()
172            .expect("all parameters must be valid");
173
174        Ok(Self {
175            http_client,
176            model: model.into(),
177            base_url,
178        })
179    }
180
181    /// Check the response status code and return an error if it is not successful
182    #[tracing::instrument(skip_all, err)]
183    async fn check_response(response: Response) -> Result<Response, Error> {
184        let status = response.status();
185        if !status.is_success() {
186            let description = response.text().await.ok();
187            BadResponseSnafu {
188                code: status.as_u16(),
189                description,
190            }
191            .fail()
192        } else {
193            Ok(response)
194        }
195    }
196
197    /// Performs an HTTP request to the Gemini API with standardized error handling.
198    ///
199    /// This method provides a generic way to make HTTP requests to the Gemini API with
200    /// consistent error handling, response checking, and deserialization. It handles:
201    /// - Building the HTTP request using a provided builder function
202    /// - Sending the request and handling network errors
203    /// - Checking the response status code for errors
204    /// - Deserializing the response using a provided deserializer function
205    ///
206    /// # Type Parameters
207    /// * `B` - A function that takes a `&Client` and returns a `RequestBuilder`
208    /// * `D` - An async function that takes ownership of a `Response` and returns a `Result<T, Error>`
209    /// * `T` - The type of the deserialized response
210    ///
211    /// # Note
212    /// The `AsyncFn` trait is a standard Rust feature (stabilized in v1.85) and does not
213    /// require any additional imports or feature flags.
214    ///
215    /// # Parameters
216    /// * `builder` - A function that constructs the HTTP request using the client
217    /// * `deserializer` - An async function that processes the response into the desired type
218    ///
219    /// # Examples
220    ///
221    /// Basic HTTP operations:
222    /// ```no_run
223    /// # use gemini_rust::client::*;
224    /// # use reqwest::Response;
225    /// # use url::Url;
226    /// # use serde_json::Value;
227    /// # use snafu::ResultExt;
228    /// # async fn examples(client: &GeminiClient) -> Result<(), Box<dyn std::error::Error>> {
229    /// # let url: Url = "https://example.com".parse()?;
230    /// # let request = Value::Null;
231    ///
232    /// // POST request with JSON payload
233    /// let _response = client
234    ///     .perform_request(
235    ///         |c| c.post(url.clone()).json(&request),
236    ///         async |r| r.json().await.context(DecodeResponseSnafu),
237    ///     )
238    ///     .await?;
239    ///
240    /// // GET request with JSON response
241    /// let _response = client
242    ///     .perform_request(
243    ///         |c| c.get(url.clone()),
244    ///         async |r| r.json().await.context(DecodeResponseSnafu),
245    ///     )
246    ///     .await?;
247    ///
248    /// // DELETE request with no response body
249    /// let _response = client
250    ///     .perform_request(|c| c.delete(url), async |_r| Ok(()))
251    ///     .await?;
252    /// # Ok(())
253    /// # }
254    /// ```
255    ///
256    /// Request returning a stream:
257    /// ```no_run
258    /// # use gemini_rust::client::*;
259    /// # use reqwest::Response;
260    /// # use url::Url;
261    /// # use serde_json::Value;
262    /// # async fn example(client: &GeminiClient) -> Result<(), Box<dyn std::error::Error>> {
263    /// # let url: Url = "https://example.com".parse()?;
264    /// # let request = Value::Null;
265    /// let _stream = client
266    ///     .perform_request(
267    ///         |c| c.post(url).json(&request),
268    ///         async |r| Ok(r.bytes_stream()),
269    ///     )
270    ///     .await?;
271    /// # Ok(())
272    /// # }
273    /// ```
274    #[tracing::instrument(skip_all)]
275    #[doc(hidden)]
276    pub async fn perform_request<
277        B: FnOnce(&Client) -> RequestBuilder,
278        D: AsyncFn(Response) -> Result<T, Error>,
279        T,
280    >(
281        &self,
282        builder: B,
283        deserializer: D,
284    ) -> Result<T, Error> {
285        let request = builder(&self.http_client);
286        tracing::debug!("request built successfully");
287        let response = request.send().await.context(PerformRequestNewSnafu)?;
288        tracing::debug!("response received successfully");
289        let response = Self::check_response(response).await?;
290        tracing::debug!("response ok");
291        deserializer(response).await
292    }
293
294    /// Perform a GET request and deserialize the JSON response.
295    ///
296    /// This is a convenience wrapper around [`perform_request`](Self::perform_request).
297    #[tracing::instrument(skip(self), fields(request.type = "get", request.url = %url))]
298    async fn get_json<T: serde::de::DeserializeOwned>(&self, url: Url) -> Result<T, Error> {
299        self.perform_request(
300            |c| c.get(url),
301            async |r| r.json().await.context(DecodeResponseSnafu),
302        )
303        .await
304    }
305
306    /// Perform a POST request with JSON body and deserialize the JSON response.
307    ///
308    /// This is a convenience wrapper around [`perform_request`](Self::perform_request).
309    #[tracing::instrument(skip(self, body), fields(request.type = "post", request.url = %url))]
310    async fn post_json<Req: serde::Serialize, Res: serde::de::DeserializeOwned>(
311        &self,
312        url: Url,
313        body: &Req,
314    ) -> Result<Res, Error> {
315        self.perform_request(
316            |c| c.post(url).json(body),
317            async |r| r.json().await.context(DecodeResponseSnafu),
318        )
319        .await
320    }
321
322    /// Generate content
323    #[instrument(skip_all, fields(
324        model,
325        messages.parts.count = request.contents.len(),
326        tools.present = request.tools.is_some(),
327        system.instruction.present = request.system_instruction.is_some(),
328        cached.content.present = request.cached_content.is_some(),
329        usage.prompt_tokens,
330        usage.candidates_tokens,
331        usage.thoughts_tokens,
332        usage.cached_content_tokens,
333        usage.total_tokens,
334    ), ret(level = Level::TRACE), err)]
335    pub(crate) async fn generate_content_raw(
336        &self,
337        request: GenerateContentRequest,
338    ) -> Result<GenerationResponse, Error> {
339        let url = self.build_url("generateContent")?;
340        let response: GenerationResponse = self.post_json(url, &request).await?;
341
342        // Record usage metadata
343        if let Some(usage) = &response.usage_metadata {
344            #[rustfmt::skip]
345            Span::current()
346                .record("usage.prompt_tokens", usage.prompt_token_count)
347                .record("usage.candidates_tokens", usage.candidates_token_count)
348                .record("usage.thoughts_tokens", usage.thoughts_token_count)
349                .record("usage.cached_content_tokens", usage.cached_content_token_count)
350                .record("usage.total_tokens", usage.total_token_count);
351
352            tracing::debug!("generation usage evaluated");
353        }
354
355        Ok(response)
356    }
357
358    /// Generate content with streaming
359    #[instrument(skip_all, fields(
360        model,
361        messages.parts.count = request.contents.len(),
362        tools.present = request.tools.is_some(),
363        system.instruction.present = request.system_instruction.is_some(),
364        cached.content.present = request.cached_content.is_some(),
365    ), err)]
366    pub(crate) async fn generate_content_stream(
367        &self,
368        request: GenerateContentRequest,
369    ) -> Result<impl TryStreamExt<Ok = GenerationResponse, Error = Error> + Send + use<>, Error>
370    {
371        let mut url = self.build_url("streamGenerateContent")?;
372        url.query_pairs_mut().append_pair("alt", "sse");
373
374        let stream = self
375            .perform_request(
376                |c| c.post(url).json(&request),
377                async |r| Ok(r.bytes_stream()),
378            )
379            .await?;
380
381        Ok(stream
382            .eventsource()
383            .map(|event| event.context(BadPartSnafu))
384            .map_ok(|event| {
385                serde_json::from_str::<GenerationResponse>(&event.data).context(DeserializeSnafu)
386            })
387            .map(|r| r.flatten()))
388    }
389
390    /// Embed content
391    #[instrument(skip_all, fields(
392        model,
393        task.type = request.task_type.as_ref().map(|t| format!("{:?}", t)),
394        task.title = request.title,
395        task.output.dimensionality = request.output_dimensionality,
396    ))]
397    pub(crate) async fn embed_content(
398        &self,
399        request: EmbedContentRequest,
400    ) -> Result<ContentEmbeddingResponse, Error> {
401        let url = self.build_url("embedContent")?;
402        self.post_json(url, &request).await
403    }
404
405    /// Batch Embed content
406    #[instrument(skip_all, fields(batch.size = request.requests.len()))]
407    pub(crate) async fn embed_content_batch(
408        &self,
409        request: BatchEmbedContentsRequest,
410    ) -> Result<BatchContentEmbeddingResponse, Error> {
411        let url = self.build_url("batchEmbedContents")?;
412        self.post_json(url, &request).await
413    }
414
415    /// Batch generate content (synchronous API that returns results immediately)
416    #[instrument(skip_all, fields(
417        batch.display_name = request.batch.display_name,
418        batch.size = request.batch.input_config.batch_size(),
419    ))]
420    pub(crate) async fn batch_generate_content(
421        &self,
422        request: BatchGenerateContentRequest,
423    ) -> Result<BatchGenerateContentResponse, Error> {
424        let url = self.build_url("batchGenerateContent")?;
425        self.post_json(url, &request).await
426    }
427
428    /// Get a batch operation
429    #[instrument(skip_all, fields(
430        operation.name = name,
431    ))]
432    pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
433        &self,
434        name: &str,
435    ) -> Result<T, Error> {
436        let url = self.build_batch_url(name, None)?;
437        self.get_json(url).await
438    }
439
440    /// List batch operations
441    #[instrument(skip_all, fields(
442        page.size = page_size,
443        page.token.present = page_token.is_some(),
444    ))]
445    pub(crate) async fn list_batch_operations(
446        &self,
447        page_size: Option<u32>,
448        page_token: Option<String>,
449    ) -> Result<ListBatchesResponse, Error> {
450        let mut url = self.build_batch_url("batches", None)?;
451
452        if let Some(size) = page_size {
453            url.query_pairs_mut()
454                .append_pair("pageSize", &size.to_string());
455        }
456        if let Some(token) = page_token {
457            url.query_pairs_mut().append_pair("pageToken", &token);
458        }
459
460        self.get_json(url).await
461    }
462
463    /// List files
464    #[instrument(skip_all, fields(
465        page.size = page_size,
466        page.token.present = page_token.is_some(),
467    ))]
468    pub(crate) async fn list_files(
469        &self,
470        page_size: Option<u32>,
471        page_token: Option<String>,
472    ) -> Result<ListFilesResponse, Error> {
473        let mut url = self.build_files_url(None)?;
474
475        if let Some(size) = page_size {
476            url.query_pairs_mut()
477                .append_pair("pageSize", &size.to_string());
478        }
479        if let Some(token) = page_token {
480            url.query_pairs_mut().append_pair("pageToken", &token);
481        }
482
483        self.get_json(url).await
484    }
485
486    /// Cancel a batch operation
487    #[instrument(skip_all, fields(
488        operation.name = name,
489    ))]
490    pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<(), Error> {
491        let url = self.build_batch_url(name, Some("cancel"))?;
492        self.perform_request(|c| c.post(url).json(&json!({})), async |_r| Ok(()))
493            .await
494    }
495
496    /// Delete a batch operation
497    #[instrument(skip_all, fields(
498        operation.name = name,
499    ))]
500    pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<(), Error> {
501        let url = self.build_batch_url(name, None)?;
502        self.perform_request(|c| c.delete(url), async |_r| Ok(()))
503            .await
504    }
505
506    async fn create_upload(
507        &self,
508        bytes: usize,
509        display_name: Option<String>,
510        mime_type: Mime,
511    ) -> Result<Url, Error> {
512        let url = self
513            .base_url
514            .join("/upload/v1beta/files")
515            .context(ConstructUrlSnafu {
516                suffix: "/upload/v1beta/files".to_string(),
517            })?;
518
519        self.perform_request(
520            |c| {
521                c.post(url)
522                    .header("X-Goog-Upload-Protocol", "resumable")
523                    .header("X-Goog-Upload-Command", "start")
524                    .header("X-Goog-Upload-Content-Length", bytes.to_string())
525                    .header("X-Goog-Upload-Header-Content-Type", mime_type.to_string())
526                    .json(&json!({"file": {"displayName": display_name}}))
527            },
528            async |r| {
529                r.headers()
530                    .get("X-Goog-Upload-URL")
531                    .context(MissingResponseHeaderSnafu {
532                        header: "X-Goog-Upload-URL",
533                    })
534                    .and_then(|upload_url| {
535                        upload_url
536                            .to_str()
537                            .map(str::to_string)
538                            .map_err(|_| Error::BadResponse {
539                                code: 500,
540                                description: Some("Missing upload URL in response".to_string()),
541                            })
542                    })
543                    .and_then(|url| Url::parse(&url).context(UrlParseSnafu))
544            },
545        )
546        .await
547    }
548
549    /// Upload a file using the resumable upload protocol.
550    #[instrument(skip_all, fields(
551        file.size = file_bytes.len(),
552        mime.type = mime_type.to_string(),
553        file.display_name = display_name.as_deref(),
554    ))]
555    pub(crate) async fn upload_file(
556        &self,
557        display_name: Option<String>,
558        file_bytes: Vec<u8>,
559        mime_type: Mime,
560    ) -> Result<File, Error> {
561        // Step 1: Create resumable upload session
562        let upload_url = self
563            .create_upload(file_bytes.len(), display_name, mime_type)
564            .await?;
565
566        // Step 2: Upload file content
567        let upload_response = self
568            .http_client
569            .post(upload_url.clone())
570            .header("X-Goog-Upload-Command", "upload, finalize")
571            .header("X-Goog-Upload-Offset", "0")
572            .body(file_bytes)
573            .send()
574            .await
575            .map_err(|e| Error::PerformRequest {
576                source: e,
577                url: upload_url,
578            })?;
579
580        let final_response = Self::check_response(upload_response).await?;
581
582        #[derive(serde::Deserialize)]
583        struct UploadResponse {
584            file: File,
585        }
586
587        let upload_response: UploadResponse =
588            final_response.json().await.context(DecodeResponseSnafu)?;
589        Ok(upload_response.file)
590    }
591
592    /// Get a file resource
593    #[instrument(skip_all, fields(
594        file.name = name,
595    ))]
596    pub(crate) async fn get_file(&self, name: &str) -> Result<File, Error> {
597        let url = self.build_files_url(Some(name))?;
598        self.get_json(url).await
599    }
600
601    /// Delete a file resource
602    #[instrument(skip_all, fields(
603        file.name = name,
604    ))]
605    pub(crate) async fn delete_file(&self, name: &str) -> Result<(), Error> {
606        let url = self.build_files_url(Some(name))?;
607        self.perform_request(|c| c.delete(url), async |_r| Ok(()))
608            .await
609    }
610
611    /// Download a file resource
612    #[instrument(skip_all, fields(
613        file.name = name,
614    ))]
615    pub(crate) async fn download_file(&self, name: &str) -> Result<Vec<u8>, Error> {
616        let mut url = self
617            .base_url
618            .join(&format!("/download/v1beta/{name}:download"))
619            .context(ConstructUrlSnafu {
620                suffix: format!("/download/v1beta/{name}:download"),
621            })?;
622        url.query_pairs_mut().append_pair("alt", "media");
623
624        self.perform_request(
625            |c| c.get(url),
626            async |r| {
627                r.bytes()
628                    .await
629                    .context(DecodeResponseSnafu)
630                    .map(|bytes| bytes.to_vec())
631            },
632        )
633        .await
634    }
635
636    /// Create cached content
637    pub(crate) async fn create_cached_content(
638        &self,
639        cached_content: CreateCachedContentRequest,
640    ) -> Result<CachedContent, Error> {
641        let url = self.build_cache_url(None)?;
642        self.post_json(url, &cached_content).await
643    }
644
645    /// Get cached content
646    pub(crate) async fn get_cached_content(&self, name: &str) -> Result<CachedContent, Error> {
647        let url = self.build_cache_url(Some(name))?;
648        self.get_json(url).await
649    }
650
651    /// Update cached content (typically to update TTL)
652    pub(crate) async fn update_cached_content(
653        &self,
654        name: &str,
655        expiration: CacheExpirationRequest,
656    ) -> Result<CachedContent, Error> {
657        let url = self.build_cache_url(Some(name))?;
658
659        // Create a minimal update payload with just the expiration
660        let update_payload = match expiration {
661            CacheExpirationRequest::Ttl { ttl } => json!({ "ttl": ttl }),
662            CacheExpirationRequest::ExpireTime { expire_time } => {
663                json!({ "expireTime": expire_time.format(&time::format_description::well_known::Rfc3339).unwrap() })
664            }
665        };
666
667        self.perform_request(
668            |c| c.patch(url.clone()).json(&update_payload),
669            async |r| r.json().await.context(DecodeResponseSnafu),
670        )
671        .await
672    }
673
674    /// Delete cached content
675    pub(crate) async fn delete_cached_content(&self, name: &str) -> Result<(), Error> {
676        let url = self.build_cache_url(Some(name))?;
677        self.perform_request(|c| c.delete(url.clone()), async |_r| Ok(()))
678            .await
679    }
680
681    /// List cached contents
682    pub(crate) async fn list_cached_contents(
683        &self,
684        page_size: Option<i32>,
685        page_token: Option<String>,
686    ) -> Result<ListCachedContentsResponse, Error> {
687        let mut url = self.build_cache_url(None)?;
688
689        if let Some(size) = page_size {
690            url.query_pairs_mut()
691                .append_pair("pageSize", &size.to_string());
692        }
693        if let Some(token) = page_token {
694            url.query_pairs_mut().append_pair("pageToken", &token);
695        }
696
697        self.get_json(url).await
698    }
699
700    /// Build a URL with the given suffix
701    #[tracing::instrument(skip(self), ret(level = Level::DEBUG))]
702    fn build_url_with_suffix(&self, suffix: &str) -> Result<Url, Error> {
703        self.base_url.join(suffix).context(ConstructUrlSnafu {
704            suffix: suffix.to_string(),
705        })
706    }
707
708    /// Build a URL for the API
709    #[tracing::instrument(skip(self), ret(level = Level::DEBUG))]
710    fn build_url(&self, endpoint: &str) -> Result<Url, Error> {
711        let suffix = format!("{}:{endpoint}", self.model);
712        self.build_url_with_suffix(&suffix)
713    }
714
715    /// Build a URL for a batch operation
716    fn build_batch_url(&self, name: &str, action: Option<&str>) -> Result<Url, Error> {
717        let suffix = action
718            .map(|a| format!("{name}:{a}"))
719            .unwrap_or_else(|| name.to_string());
720        self.build_url_with_suffix(&suffix)
721    }
722
723    /// Build a URL for file operations
724    fn build_files_url(&self, name: Option<&str>) -> Result<Url, Error> {
725        let suffix = name
726            .map(|n| format!("files/{}", n.strip_prefix("files/").unwrap_or(n)))
727            .unwrap_or_else(|| "files".to_string());
728        self.build_url_with_suffix(&suffix)
729    }
730
731    /// Build a URL for cache operations
732    fn build_cache_url(&self, name: Option<&str>) -> Result<Url, Error> {
733        let suffix = name
734            .map(|n| {
735                if n.starts_with("cachedContents/") {
736                    n.to_string()
737                } else {
738                    format!("cachedContents/{}", n)
739                }
740            })
741            .unwrap_or_else(|| "cachedContents".to_string());
742        self.build_url_with_suffix(&suffix)
743    }
744}
745
746/// A builder for the `Gemini` client.
747///
748/// # Examples
749///
750/// ## Basic usage
751///
752/// ```no_run
753/// use gemini_rust::{GeminiBuilder, Model};
754///
755/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
756/// let gemini = GeminiBuilder::new("YOUR_API_KEY")
757///     .with_model(Model::Gemini25Pro)
758///     .build()?;
759/// # Ok(())
760/// # }
761/// ```
762///
763/// ## With proxy configuration
764///
765/// ```no_run
766/// use gemini_rust::{GeminiBuilder, Model};
767/// use reqwest::{ClientBuilder, Proxy};
768///
769/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
770/// let proxy = Proxy::https("https://my.proxy")?;
771/// let http_client = ClientBuilder::new().proxy(proxy);
772///
773/// let gemini = GeminiBuilder::new("YOUR_API_KEY")
774///     .with_http_client(http_client)
775///     .build()?;
776/// # Ok(())
777/// # }
778/// ```
779pub struct GeminiBuilder {
780    key: String,
781    model: Model,
782    client_builder: ClientBuilder,
783    base_url: Url,
784}
785
786impl GeminiBuilder {
787    /// Creates a new `GeminiBuilder` with the given API key.
788    pub fn new<K: Into<String>>(key: K) -> Self {
789        Self {
790            key: key.into(),
791            model: Model::default(),
792            client_builder: ClientBuilder::default(),
793            base_url: DEFAULT_BASE_URL.clone(),
794        }
795    }
796
797    /// Sets the model for the client.
798    pub fn with_model<M: Into<Model>>(mut self, model: M) -> Self {
799        self.model = model.into();
800        self
801    }
802
803    /// Sets a custom `reqwest::ClientBuilder`.
804    pub fn with_http_client(mut self, client_builder: ClientBuilder) -> Self {
805        self.client_builder = client_builder;
806        self
807    }
808
809    /// Sets a custom base URL for the API.
810    pub fn with_base_url(mut self, base_url: Url) -> Self {
811        self.base_url = base_url;
812        self
813    }
814
815    /// Builds the `Gemini` client.
816    pub fn build(self) -> Result<Gemini, Error> {
817        Ok(Gemini {
818            client: Arc::new(GeminiClient::with_base_url(
819                self.client_builder,
820                self.key,
821                self.model,
822                self.base_url,
823            )?),
824        })
825    }
826}
827
828/// Client for the Gemini API
829#[derive(Clone)]
830pub struct Gemini {
831    client: Arc<GeminiClient>,
832}
833
834impl Gemini {
835    /// Create a new client with the specified API key
836    pub fn new<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
837        Self::with_model(api_key, Model::default())
838    }
839
840    /// Create a new client for the Gemini Pro model
841    pub fn pro<K: AsRef<str>>(api_key: K) -> Result<Self, Error> {
842        Self::with_model(api_key, Model::Gemini25Pro)
843    }
844
845    /// Create a new client with the specified API key and model
846    pub fn with_model<K: AsRef<str>, M: Into<Model>>(api_key: K, model: M) -> Result<Self, Error> {
847        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.clone())
848    }
849
850    /// Create a new client with custom base URL
851    pub fn with_base_url<K: AsRef<str>>(api_key: K, base_url: Url) -> Result<Self, Error> {
852        Self::with_model_and_base_url(api_key, Model::default(), base_url)
853    }
854
855    /// Create a new client with the specified API key, model, and base URL
856    pub fn with_model_and_base_url<K: AsRef<str>, M: Into<Model>>(
857        api_key: K,
858        model: M,
859        base_url: Url,
860    ) -> Result<Self, Error> {
861        let client =
862            GeminiClient::with_base_url(Default::default(), api_key, model.into(), base_url)?;
863        Ok(Self {
864            client: Arc::new(client),
865        })
866    }
867
868    /// Start building a content generation request
869    pub fn generate_content(&self) -> ContentBuilder {
870        ContentBuilder::new(self.client.clone())
871    }
872
873    /// Start building a content embedding request
874    pub fn embed_content(&self) -> EmbedBuilder {
875        EmbedBuilder::new(self.client.clone())
876    }
877
878    /// Start building a batch content generation request
879    pub fn batch_generate_content(&self) -> BatchBuilder {
880        BatchBuilder::new(self.client.clone())
881    }
882
883    /// Get a handle to a batch operation by its name.
884    pub fn get_batch(&self, name: &str) -> BatchHandle {
885        BatchHandle::new(name.to_string(), self.client.clone())
886    }
887
888    /// Lists batch operations.
889    ///
890    /// This method returns a stream that handles pagination automatically.
891    pub fn list_batches(
892        &self,
893        page_size: impl Into<Option<u32>>,
894    ) -> impl Stream<Item = Result<BatchOperation, Error>> + Send {
895        let client = self.client.clone();
896        let page_size = page_size.into();
897        async_stream::try_stream! {
898            let mut page_token: Option<String> = None;
899            loop {
900                let response = client
901                    .list_batch_operations(page_size, page_token.clone())
902                    .await?;
903
904                for operation in response.operations {
905                    yield operation;
906                }
907
908                if let Some(next_page_token) = response.next_page_token {
909                    page_token = Some(next_page_token);
910                } else {
911                    break;
912                }
913            }
914        }
915    }
916
917    /// Create cached content with a fluent API.
918    pub fn create_cache(&self) -> CacheBuilder {
919        CacheBuilder::new(self.client.clone())
920    }
921
922    /// Get a handle to cached content by its name.
923    pub fn get_cached_content(&self, name: &str) -> CachedContentHandle {
924        CachedContentHandle::new(name.to_string(), self.client.clone())
925    }
926
927    /// Lists cached contents.
928    ///
929    /// This method returns a stream that handles pagination automatically.
930    pub fn list_cached_contents(
931        &self,
932        page_size: impl Into<Option<i32>>,
933    ) -> impl Stream<Item = Result<CachedContentSummary, Error>> + Send {
934        let client = self.client.clone();
935        let page_size = page_size.into();
936        async_stream::try_stream! {
937            let mut page_token: Option<String> = None;
938            loop {
939                let response = client
940                    .list_cached_contents(page_size, page_token.clone())
941                    .await?;
942
943                for cached_content in response.cached_contents {
944                    yield cached_content;
945                }
946
947                if let Some(next_page_token) = response.next_page_token {
948                    page_token = Some(next_page_token);
949                } else {
950                    break;
951                }
952            }
953        }
954    }
955
956    /// Start building a file resource
957    pub fn create_file<B: Into<Vec<u8>>>(&self, bytes: B) -> crate::files::builder::FileBuilder {
958        crate::files::builder::FileBuilder::new(self.client.clone(), bytes)
959    }
960
961    /// Get a handle to a file by its name.
962    pub async fn get_file(&self, name: &str) -> Result<FileHandle, Error> {
963        let file = self.client.get_file(name).await?;
964        Ok(FileHandle::new(self.client.clone(), file))
965    }
966
967    /// Lists files.
968    ///
969    /// This method returns a stream that handles pagination automatically.
970    pub fn list_files(
971        &self,
972        page_size: impl Into<Option<u32>>,
973    ) -> impl Stream<Item = Result<FileHandle, Error>> + Send {
974        let client = self.client.clone();
975        let page_size = page_size.into();
976        async_stream::try_stream! {
977            let mut page_token: Option<String> = None;
978            loop {
979                let response = client
980                    .list_files(page_size, page_token.clone())
981                    .await?;
982
983                for file in response.files {
984                    yield FileHandle::new(client.clone(), file);
985                }
986
987                if let Some(next_page_token) = response.next_page_token {
988                    page_token = Some(next_page_token);
989                } else {
990                    break;
991                }
992            }
993        }
994    }
995}