1use crate::{
2 batch_builder::BatchBuilder,
3 content_builder::ContentBuilder,
4 embed_builder::EmbedBuilder,
5 models::{
6 BatchContentEmbeddingResponse, BatchEmbedContentsRequest, BatchGenerateContentRequest,
7 BatchGenerateContentResponse, BatchOperation, ContentEmbeddingResponse,
8 EmbedContentRequest, GenerateContentRequest, GenerationResponse, ListBatchesResponse,
9 },
10 Batch, Error, Result,
11};
12use futures::stream::Stream;
13use reqwest::Client;
14use serde_json::Value;
15use std::{pin::Pin, sync::Arc};
16use url::Url;
17
/// Root of the Gemini REST API (v1beta). Model and batch resource paths are appended to it,
/// which is why it ends with a `/`.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
/// Model used when none is given explicitly (`Gemini::new`, `Gemini::with_base_url`).
const DEFAULT_MODEL: &str = "models/gemini-2.5-flash";
20
21pub(crate) struct GeminiClient {
23 http_client: Client,
24 api_key: String,
25 pub model: String,
26 base_url: String,
27}
28
29impl GeminiClient {
30 #[allow(dead_code)]
32 fn new(api_key: impl Into<String>, model: String) -> Self {
33 Self::with_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
34 }
35
36 fn with_base_url(api_key: impl Into<String>, model: String, base_url: String) -> Self {
38 Self {
39 http_client: Client::new(),
40 api_key: api_key.into(),
41 model,
42 base_url,
43 }
44 }
45
46 pub(crate) async fn generate_content_raw(
48 &self,
49 request: GenerateContentRequest,
50 ) -> Result<GenerationResponse> {
51 let url = self.build_url("generateContent")?;
52
53 let response = self.http_client.post(url).json(&request).send().await?;
54
55 let status = response.status();
56 if !status.is_success() {
57 let error_text = response.text().await?;
58 return Err(Error::ApiError {
59 status_code: status.as_u16(),
60 message: error_text,
61 });
62 }
63
64 let response = response.json().await?;
65
66 Ok(response)
67 }
68
69 pub(crate) async fn generate_content_stream(
71 &self,
72 request: GenerateContentRequest,
73 ) -> Result<Pin<Box<dyn Stream<Item = Result<GenerationResponse>> + Send>>> {
74 let url = self.build_url("streamGenerateContent")?;
75
76 let response = self.http_client.post(url).json(&request).send().await?;
77
78 let status = response.status();
79 if !status.is_success() {
80 let error_text = response.text().await?;
81 return Err(Error::ApiError {
82 status_code: status.as_u16(),
83 message: error_text,
84 });
85 }
86
87 let bytes = response.bytes().await?;
89 let text = String::from_utf8_lossy(&bytes);
90
91 let responses: Vec<Result<GenerationResponse>> =
93 match serde_json::from_str::<Vec<GenerationResponse>>(&text) {
94 Ok(json_array) => json_array.into_iter().map(Ok).collect(),
95 Err(e) => {
96 vec![Err(Error::JsonError(e))]
97 }
98 };
99
100 let stream = futures::stream::iter(responses);
101 Ok(Box::pin(stream))
102 }
103
104 pub(crate) async fn embed_content(
106 &self,
107 request: EmbedContentRequest,
108 ) -> Result<ContentEmbeddingResponse> {
109 let value = self.post_json(request, "embedContent").await?;
110 let response = serde_json::from_value::<ContentEmbeddingResponse>(value)?;
111
112 Ok(response)
113 }
114
115 pub(crate) async fn embed_content_batch(
117 &self,
118 request: BatchEmbedContentsRequest,
119 ) -> Result<BatchContentEmbeddingResponse> {
120 let value = self.post_json(request, "batchEmbedContents").await?;
121 let response = serde_json::from_value::<BatchContentEmbeddingResponse>(value)?;
122
123 Ok(response)
124 }
125
126 pub(crate) async fn batch_generate_content_sync(
128 &self,
129 request: BatchGenerateContentRequest,
130 ) -> Result<BatchGenerateContentResponse> {
131 let value = self.post_json(request, "batchGenerateContent").await?;
132 let response = serde_json::from_value::<BatchGenerateContentResponse>(value)?;
133 Ok(response)
134 }
135
136 pub(crate) async fn get_batch_operation<T: serde::de::DeserializeOwned>(
138 &self,
139 name: &str,
140 ) -> Result<T> {
141 let url = self.build_batch_url(name, None)?;
142 let response = self.http_client.get(url).send().await?;
143
144 let status = response.status();
145 if !status.is_success() {
146 let error_text = response.text().await?;
147 return Err(Error::ApiError {
148 status_code: status.as_u16(),
149 message: error_text,
150 });
151 }
152
153 let response = response.json().await?;
154 Ok(response)
155 }
156
157 pub(crate) async fn list_batch_operations(
159 &self,
160 page_size: Option<u32>,
161 page_token: Option<String>,
162 ) -> Result<ListBatchesResponse> {
163 let mut url = self.build_batch_url("batches", None)?;
164
165 if let Some(size) = page_size {
166 url.query_pairs_mut()
167 .append_pair("pageSize", &size.to_string());
168 }
169 if let Some(token) = page_token {
170 url.query_pairs_mut().append_pair("pageToken", &token);
171 }
172
173 let response = self.http_client.get(url).send().await?;
174
175 let status = response.status();
176 if !status.is_success() {
177 let error_text = response.text().await?;
178 return Err(Error::ApiError {
179 status_code: status.as_u16(),
180 message: error_text,
181 });
182 }
183
184 let response = response.json().await?;
185 Ok(response)
186 }
187
188 pub(crate) async fn cancel_batch_operation(&self, name: &str) -> Result<()> {
190 let url = self.build_batch_url(name, Some("cancel"))?;
191 let response = self
192 .http_client
193 .post(url)
194 .json(&serde_json::json!({}))
195 .send()
196 .await?;
197
198 let status = response.status();
199 if !status.is_success() {
200 let error_text = response.text().await?;
201 return Err(Error::ApiError {
202 status_code: status.as_u16(),
203 message: error_text,
204 });
205 }
206
207 Ok(())
208 }
209
210 pub(crate) async fn delete_batch_operation(&self, name: &str) -> Result<()> {
212 let url = self.build_batch_url(name, None)?;
213 let response = self.http_client.delete(url).send().await?;
214
215 let status = response.status();
216 if !status.is_success() {
217 let error_text = response.text().await?;
218 return Err(Error::ApiError {
219 status_code: status.as_u16(),
220 message: error_text,
221 });
222 }
223
224 Ok(())
225 }
226
227 async fn post_json<T: serde::Serialize>(&self, request: T, endpoint: &str) -> Result<Value> {
229 let url = self.build_url(endpoint)?;
230
231 let response = self.http_client.post(url).json(&request).send().await?;
232
233 let status = response.status();
234 if !status.is_success() {
235 let error_text = response.text().await?;
236 return Err(Error::ApiError {
237 status_code: status.as_u16(),
238 message: error_text,
239 });
240 }
241
242 let response = response.json().await?;
243 Ok(response)
244 }
245
246 fn build_url(&self, endpoint: &str) -> Result<Url> {
248 let url_str = format!(
249 "{}{}:{}?key={}",
250 self.base_url, self.model, endpoint, self.api_key
251 );
252 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
253 }
254
255 fn build_batch_url(&self, name: &str, action: Option<&str>) -> Result<Url> {
257 let action_suffix = action.map_or("".to_string(), |a| format!(":{}", a));
258 let url_str = format!(
259 "{}{}{}?key={}",
260 self.base_url, name, action_suffix, self.api_key
261 );
262 Url::parse(&url_str).map_err(|e| Error::RequestError(e.to_string()))
263 }
264}
265
/// Entry point to the Gemini API.
///
/// Cloning is cheap: it only bumps the reference count of the shared inner client. All
/// clones therefore reuse the same HTTP client, API key, model and base URL.
#[derive(Clone)]
pub struct Gemini {
    client: Arc<GeminiClient>,
}
271
272impl Gemini {
273 pub fn new(api_key: impl Into<String>) -> Self {
275 Self::with_model(api_key, DEFAULT_MODEL.to_string())
276 }
277
278 pub fn pro(api_key: impl Into<String>) -> Self {
280 Self::with_model(api_key, "models/gemini-2.5-pro".to_string())
281 }
282
283 pub fn with_model(api_key: impl Into<String>, model: String) -> Self {
285 Self::with_model_and_base_url(api_key, model, DEFAULT_BASE_URL.to_string())
286 }
287
288 pub fn with_base_url(api_key: impl Into<String>, base_url: String) -> Self {
290 Self::with_model_and_base_url(api_key, DEFAULT_MODEL.to_string(), base_url)
291 }
292
293 pub fn with_model_and_base_url(
295 api_key: impl Into<String>,
296 model: String,
297 base_url: String,
298 ) -> Self {
299 let client = GeminiClient::with_base_url(api_key.into(), model, base_url);
300 Self {
301 client: Arc::new(client),
302 }
303 }
304
305 pub fn generate_content(&self) -> ContentBuilder {
307 ContentBuilder::new(self.client.clone())
308 }
309
310 pub fn embed_content(&self) -> EmbedBuilder {
312 EmbedBuilder::new(self.client.clone())
313 }
314
315 pub fn batch_generate_content_sync(&self) -> BatchBuilder {
317 BatchBuilder::new(self.client.clone())
318 }
319
320 pub fn get_batch(&self, name: &str) -> Batch {
322 Batch::new(name.to_string(), self.client.clone())
323 }
324
325 pub fn list_batches(
329 &self,
330 page_size: impl Into<Option<u32>>,
331 ) -> impl Stream<Item = Result<BatchOperation>> + Send {
332 let client = self.client.clone();
333 let page_size = page_size.into();
334 async_stream::try_stream! {
335 let mut page_token: Option<String> = None;
336 loop {
337 let response = client
338 .list_batch_operations(page_size, page_token.clone())
339 .await?;
340
341 for operation in response.operations {
342 yield operation;
343 }
344
345 if let Some(next_page_token) = response.next_page_token {
346 page_token = Some(next_page_token);
347 } else {
348 break;
349 }
350 }
351 }
352 }
353}