gemini_rust/
client.rs

1use crate::{
2    batch_builder::BatchBuilder,
3    content_builder::ContentBuilder,
4    embed_builder::EmbedBuilder,
5    models::{
6        BatchContentEmbeddingResponse, BatchEmbedContentsRequest, BatchGenerateContentRequest,
7        BatchGenerateContentResponse, BatchOperation, ContentEmbeddingResponse,
8        EmbedContentRequest, GenerateContentRequest, GenerationResponse, ListBatchesResponse,
9    },
10    Batch, Error, Result,
11};
12use futures::stream::Stream;
13use reqwest::Client;
14use serde_json::Value;
15use std::{pin::Pin, sync::Arc};
16use url::Url;
17
/// Root of the Gemini REST API used when no custom base URL is supplied.
///
/// Resource paths are appended to this string verbatim, so the trailing
/// slash is significant.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model resource name used when the caller does not choose a model.
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
20
21/// Internal client for making requests to the Gemini API
22pub(crate) struct GeminiClient {
23    http_client: Client,
24    api_key: String,
25    pub model: String,
26    base_url: String,
27}
28
29impl GeminiClient {
30    /// Create a new client
31    #[allow(dead_code)]
32    fn new(api_key: impl Into<String>, model: String) -> Self {
33        Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34    }
35
36    /// Create a new client with custom base URL
37    fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38        Self {
39            http_client: Client::new(),
40            api_key: api_key.into(),
41            model,
42            base_url,
43        }
44    }
45
46    /// Generate content
47    pub(crate) async fn generate_content_raw(
48        &self,
49        request: GenerateContentRequest,
50    ) -> Result<GenerationResponse> {
51        let url = self.build_url("generateContent")?;
52
53        let response = self.http_client.post(url).json(&request).send().await?;
54
55        let status = response.status();
56        if !status.is_success() {
57            let error_text = response.text().await?;
58            return Err(Error::ApiError {
59                status_code: status.as_u16(),
60                message: error_text,
61            });
62        }
63
64        let response = response.json().await?;
65
66        Ok(response)
67    }
68
69    /// Generate content with streaming
70    pub(crate) async fn generate_content_stream(
71        &self,
72        request: GenerateContentRequest,
73    ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74        let url = self.build_url("streamGenerateContent")?;
75
76        let response = self.http_client.post(url).json(&request).send().await?;
77
78        let status = response.status();
79        if !status.is_success() {
80            let error_text = response.text().await?;
81            return Err(Error::ApiError {
82                status_code: status.as_u16(),
83                message: error_text,
84            });
85        }
86
87        // Get the full response as bytes and parse as JSON array
88        let bytes = response.bytes().await?;
89        let text = String::from_utf8_lossy(&bytes);
90
91        // The Gemini API returns a JSON array format like: [{json1}, {json2}, {json3}]
92        let responses: Vec<Result<GenerationResponse>> =
93            match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
94                Ok(json_array) => json_array.into_iter().map(Ok).collect(),
95                Err(e) => {
96                    vec![Err(Error::JsonError(e))]
97                }
98            };
99
100        let stream = futures::stream::iter(responses);
101        Ok(Box::pin(stream))
102    }
103
104    /// Embed content
105    pub(crate) async fn embed_content(
106        &self,
107        request: EmbedContentRequest,
108    ) -> Result<ContentEmbeddingResponse> {
109        let value = self.post_json(request, "embedContent").await?;
110        let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
111
112        Ok(response)
113    }
114
115    /// Batch Embed content
116    pub(crate) async fn embed_content_batch(
117        &self,
118        request: BatchEmbedContentsRequest,
119    ) -> Result<BatchContentEmbeddingResponse> {
120        let value = self.post_json(request, "batchEmbedContents").await?;
121        let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
122
123        Ok(response)
124    }
125
126    /// Synchronous Batch Generate content
127    pub(crate) async fn batch_generate_content_sync(
128        &self,
129        request: BatchGenerateContentRequest,
130    ) -> Result<BatchGenerateContentResponse> {
131        let value = self.post_json(request, "batchGenerateContent").await?;
132        let response = serde_json::from_value::<BatchGenerateContentResponse>(value)?;
133        Ok(response)
134    }
135
136    /// Get a batch operation
137    pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
138        &self,
139        name: &str,
140    ) -> Result<T> {
141        let url = self.build_batch_url(name, None)?;
142        let response = self.http_client.get(url).send().await?;
143
144        let status = response.status();
145        if !status.is_success() {
146            let error_text = response.text().await?;
147            return Err(Error::ApiError {
148                status_code: status.as_u16(),
149                message: error_text,
150            });
151        }
152
153        let response = response.json().await?;
154        Ok(response)
155    }
156
157    /// List batch operations
158    pub(crate) async fn list_batch_operations(
159        &self,
160        page_size: Option<u32>,
161        page_token: Option<String>,
162    ) -> Result<ListBatchesResponse> {
163        let mut url = self.build_batch_url("batches", None)?;
164
165        if let Some(size) = page_size {
166            url.query_pairs_mut()
167                .append_pair("pageSize", &size.to_string());
168        }
169        if let Some(token) = page_token {
170            url.query_pairs_mut().append_pair("pageToken", &token);
171        }
172
173        let response = self.http_client.get(url).send().await?;
174
175        let status = response.status();
176        if !status.is_success() {
177            let error_text = response.text().await?;
178            return Err(Error::ApiError {
179                status_code: status.as_u16(),
180                message: error_text,
181            });
182        }
183
184        let response = response.json().await?;
185        Ok(response)
186    }
187
188    /// Cancel a batch operation
189    pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<()> {
190        let url = self.build_batch_url(name, Some("cancel"))?;
191        let response = self
192            .http_client
193            .post(url)
194            .json(&serde_json::json!({}))
195            .send()
196            .await?;
197
198        let status = response.status();
199        if !status.is_success() {
200            let error_text = response.text().await?;
201            return Err(Error::ApiError {
202                status_code: status.as_u16(),
203                message: error_text,
204            });
205        }
206
207        Ok(())
208    }
209
210    /// Delete a batch operation
211    pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<()> {
212        let url = self.build_batch_url(name, None)?;
213        let response = self.http_client.delete(url).send().await?;
214
215        let status = response.status();
216        if !status.is_success() {
217            let error_text = response.text().await?;
218            return Err(Error::ApiError {
219                status_code: status.as_u16(),
220                message: error_text,
221            });
222        }
223
224        Ok(())
225    }
226
227    /// Post JSON to an endpoint
228    async fn post_json<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
229        let url = self.build_url(endpoint)?;
230
231        let response = self.http_client.post(url).json(&request).send().await?;
232
233        let status = response.status();
234        if !status.is_success() {
235            let error_text = response.text().await?;
236            return Err(Error::ApiError {
237                status_code: status.as_u16(),
238                message: error_text,
239            });
240        }
241
242        let response = response.json().await?;
243        Ok(response)
244    }
245
246    /// Build a URL for the API
247    fn build_url(&self, endpoint: &str) -> Result<Url> {
248        let url_str = format!(
249            "{}{}:{}?key={}",
250            self.base_url, self.model, endpoint, self.api_key
251        );
252        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
253    }
254
255    /// Build a URL for a batch operation
256    fn build_batch_url(&self, name: &str, action: Option<&str>) -> Result<Url> {
257        let action_suffix = action.map_or("".to_string(), |a| format!(":{}", a));
258        let url_str = format!(
259            "{}{}{}?key={}",
260            self.base_url, name, action_suffix, self.api_key
261        );
262        Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
263    }
264}
265
266/// Client for the Gemini API
267#[derive(Clone)]
268pub struct Gemini {
269    client: Arc<GeminiClient>,
270}
271
272impl Gemini {
273    /// Create a new client with the specified API key
274    pub fn new(api_key: impl Into<String>) -> Self {
275        Self::with_model(api_key, DEFAULT_MODEL.to_string())
276    }
277
278    /// Create a new client for the Gemini Pro model
279    pub fn pro(api_key: impl Into<String>) -> Self {
280        Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
281    }
282
283    /// Create a new client with the specified API key and model
284    pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
285        Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
286    }
287
288    /// Create a new client with custom base URL
289    pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
290        Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
291    }
292
293    /// Create a new client with the specified API key, model, and base URL
294    pub fn with_model_and_base_url(
295        api_key: impl Into<String>,
296        model: String,
297        base_url: String,
298    ) -> Self {
299        let client = GeminiClient::with_base_url(api_key.into(), model, base_url);
300        Self {
301            client: Arc::new(client),
302        }
303    }
304
305    /// Start building a content generation request
306    pub fn generate_content(&self) -> ContentBuilder {
307        ContentBuilder::new(self.client.clone())
308    }
309
310    /// Start building a content generation request
311    pub fn embed_content(&self) -> EmbedBuilder {
312        EmbedBuilder::new(self.client.clone())
313    }
314
315    /// Start building a synchronous batch content generation request
316    pub fn batch_generate_content_sync(&self) -> BatchBuilder {
317        BatchBuilder::new(self.client.clone())
318    }
319
320    /// Get a handle to a batch operation by its name.
321    pub fn get_batch(&self, name: &str) -> Batch {
322        Batch::new(name.to_string(), self.client.clone())
323    }
324
325    /// Lists batch operations.
326    ///
327    /// This method returns a stream that handles pagination automatically.
328    pub fn list_batches(
329        &self,
330        page_size: impl Into<Option<u32>>,
331    ) -> impl Stream<Item = Result<BatchOperation>> + Send {
332        let client = self.client.clone();
333        let page_size = page_size.into();
334        async_stream::try_stream! {
335            let mut page_token: Option<String> = None;
336            loop {
337                let response = client
338                    .list_batch_operations(page_size, page_token.clone())
339                    .await?;
340
341                for operation in response.operations {
342                    yield operation;
343                }
344
345                if let Some(next_page_token) = response.next_page_token {
346                    page_token = Some(next_page_token);
347                } else {
348                    break;
349                }
350            }
351        }
352    }
353}