google_generative_ai_rs/v1/
api.rs1use futures::prelude::*;
3use futures::stream::StreamExt;
4use reqwest::StatusCode;
5use reqwest_streams::error::StreamBodyError;
6use reqwest_streams::*;
7use serde_json;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Mutex;
12
13use crate::v1::errors::GoogleAPIError;
14use crate::v1::gemini::request::Request;
15use crate::v1::gemini::response::GeminiResponse;
16use crate::v1::gemini::Model;
17
18use super::gemini::response::{GeminiErrorResponse, StreamedGeminiResponse, TokenCount};
19use super::gemini::{ModelInformation, ModelInformationList, ResponseType};
20
/// Root of the public Generative Language REST API.
///
/// Building with the `beta` feature targets the `v1beta` surface; otherwise the stable
/// `v1` surface is used. Endpoint URLs are formed by appending `/models...` to this value.
const PUBLIC_API_URL_BASE: &str = if cfg!(feature = "beta") {
    "https://generativelanguage.googleapis.com/v1beta"
} else {
    "https://generativelanguage.googleapis.com/v1"
};
26
27#[derive(Debug)]
29pub enum PostResult {
30 Rest(GeminiResponse),
31 Streamed(StreamedGeminiResponse),
32 Count(TokenCount),
33}
34impl PostResult {
35 pub fn rest(self) -> Option<GeminiResponse> {
36 match self {
37 PostResult::Rest(response) => Some(response),
38 _ => None,
39 }
40 }
41 pub fn streamed(self) -> Option<StreamedGeminiResponse> {
42 match self {
43 PostResult::Streamed(streamed_response) => Some(streamed_response),
44 _ => None,
45 }
46 }
47 pub fn count(self) -> Option<TokenCount> {
48 match self {
49 PostResult::Count(response) => Some(response),
50 _ => None,
51 }
52 }
53}
54
55pub struct Client {
57 pub url: String,
58 pub model: Model,
59 pub region: Option<String>,
60 pub project_id: Option<String>,
61 pub response_type: ResponseType,
62}
63
64impl Client {
68 pub fn new(api_key: String) -> Self {
70 let url = Url::new(&Model::default(), api_key, &ResponseType::GenerateContent);
71 Self {
72 url: url.url,
73 model: Model::default(),
74 region: None,
75 project_id: None,
76 response_type: ResponseType::GenerateContent,
77 }
78 }
79
80 pub fn new_from_response_type(response_type: ResponseType, api_key: String) -> Self {
82 let url = Url::new(&Model::default(), api_key, &response_type);
83 Self {
84 url: url.url,
85 model: Model::default(),
86 region: None,
87 project_id: None,
88 response_type,
89 }
90 }
91
92 pub fn new_from_model(model: Model, api_key: String) -> Self {
94 let url = Url::new(&model, api_key, &ResponseType::GenerateContent);
95 Self {
96 url: url.url,
97 model,
98 region: None,
99 project_id: None,
100 response_type: ResponseType::GenerateContent,
101 }
102 }
103
104 pub fn new_from_model_response_type(
106 model: Model,
107 api_key: String,
108 response_type: ResponseType,
109 ) -> Self {
110 let url = Url::new(&model, api_key, &response_type);
111 Self {
112 url: url.url,
113 model,
114 region: None,
115 project_id: None,
116 response_type,
117 }
118 }
119
120 pub async fn post(
122 &self,
123 timeout: u64,
124 api_request: &Request,
125 ) -> Result<PostResult, GoogleAPIError> {
126 let client: reqwest::Client = self.get_reqwest_client(timeout)?;
127 match self.response_type {
128 ResponseType::GenerateContent => {
129 let result = self.get_post_result(client, api_request).await?;
130 Ok(PostResult::Rest(result))
131 }
132 ResponseType::StreamGenerateContent => {
133 let result = self.get_streamed_post_result(client, api_request).await?;
134 Ok(PostResult::Streamed(result))
135 }
136 ResponseType::CountTokens => {
137 let result = self.get_token_count(client, api_request).await?;
138 Ok(PostResult::Count(result))
139 }
140 _ => Err(GoogleAPIError {
141 message: format!("Unsupported response type: {:?}", self.response_type),
142 code: None,
143 }),
144 }
145 }
146
147 async fn get_post_result(
149 &self,
150 client: reqwest::Client,
151 api_request: &Request,
152 ) -> Result<GeminiResponse, GoogleAPIError> {
153 let token_option = self.get_auth_token_option().await?;
154
155 let result = self
156 .get_post_response(client, api_request, token_option)
157 .await;
158
159 if let Ok(result) = result {
160 match result.status() {
161 reqwest::StatusCode::OK => {
162 Ok(result.json::<GeminiResponse>().await.map_err(|e|GoogleAPIError {
163 message: format!(
164 "Failed to deserialize API response into v1::gemini::response::GeminiResponse: {}",
165 e
166 ),
167 code: None,
168 })?)
169 },
170 _ => {
171 let status = result.status();
172
173 match result.json::<GeminiErrorResponse>().await {
174 Ok(GeminiErrorResponse::Error { message, .. }) => Err(self.new_error_from_api_message(status, message)),
175 Err(_) => Err(self.new_error_from_status_code(status)),
176 }
177 },
178 }
179 } else {
180 Err(self.new_error_from_reqwest_error(result.unwrap_err()))
181 }
182 }
183
184 async fn get_streamed_post_result(
187 &self,
188 client: reqwest::Client,
189 api_request: &Request,
190 ) -> Result<StreamedGeminiResponse, GoogleAPIError> {
191 let token_option = self.get_auth_token_option().await?;
192
193 let result = self
194 .get_post_response(client, api_request, token_option)
195 .await;
196
197 match result {
198 Ok(response) => match response.status() {
199 reqwest::StatusCode::OK => {
200 let json_stream = response.json_array_stream::<serde_json::Value>(2048); Ok(StreamedGeminiResponse {
204 response_stream: Some(json_stream),
205 })
206 }
207 _ => Err(self.new_error_from_status_code(response.status())),
208 },
209 Err(e) => Err(self.new_error_from_reqwest_error(e)),
210 }
211 }
212
213 pub async fn for_each_async<F, Fut>(
232 stream: Pin<Box<dyn Stream<Item = Result<serde_json::Value, StreamBodyError>> + Send>>,
233 consumer: F,
234 ) where
235 F: FnMut(GeminiResponse) -> Fut + Send + 'static,
236 Fut: Future<Output = ()>,
237 {
238 let consumer = Arc::new(Mutex::new(consumer));
240
241 stream
246 .for_each_concurrent(None, |item: Result<serde_json::Value, StreamBodyError>| {
247 let consumer = Arc::clone(&consumer);
248 async move {
249 let res = match item {
250 Ok(result) => {
251 Client::convert_json_value_to_response(&result).map_err(|e| {
252 GoogleAPIError {
253 message: format!(
254 "Failed to get JSON stream from request: {}",
255 e
256 ),
257 code: None,
258 }
259 })
260 }
261 Err(e) => Err(GoogleAPIError {
262 message: format!("Failed to get JSON stream from request: {}", e),
263 code: None,
264 }),
265 };
266
267 if let Ok(response) = res {
268 let mut consumer = consumer.lock().await;
269 consumer(response).await;
270 }
271 }
272 })
273 .await;
274 }
275
276 async fn get_post_response(
282 &self,
283 client: reqwest::Client,
284 api_request: &Request,
285 authn_token: Option<String>,
286 ) -> Result<reqwest::Response, reqwest::Error> {
287 let mut request_builder = client
288 .post(&self.url)
289 .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME"))
290 .header(reqwest::header::CONTENT_TYPE, "application/json");
291
292 if let Some(token) = authn_token {
294 request_builder = request_builder.bearer_auth(token);
295 }
296
297 request_builder.json(&api_request).send().await
298 }
299 pub async fn get_token_count(
305 &self,
306 client: reqwest::Client,
307 api_request: &Request,
308 ) -> Result<TokenCount, GoogleAPIError> {
309 let token_option = self.get_auth_token_option().await?;
310
311 let result = self
312 .get_post_response(client, api_request, token_option)
313 .await;
314
315 match result {
316 Ok(response) => match response.status() {
317 reqwest::StatusCode::OK => Ok(response.json::<TokenCount>().await.map_err(|e|GoogleAPIError {
318 message: format!(
319 "Failed to deserialize API response into v1::gemini::response::TokenCount: {}",
320 e
321 ),
322 code: None,
323 })?),
324 _ => Err(self.new_error_from_status_code(response.status())),
325 },
326 Err(e) => Err(self.new_error_from_reqwest_error(e)),
327 }
328 }
329
330 async fn get(
332 &self,
333 timeout: u64,
334 ) -> Result<Result<reqwest::Response, reqwest::Error>, GoogleAPIError> {
335 let client: reqwest::Client = self.get_reqwest_client(timeout)?;
336 let result = client
337 .get(&self.url)
338 .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME"))
339 .send()
340 .await;
341 Ok(result)
342 }
343 pub async fn get_model(&self, timeout: u64) -> Result<ModelInformation, GoogleAPIError> {
347 let result = self.get(timeout).await?;
348
349 match result {
350 Ok(response) => {
351 match response.status() {
352 reqwest::StatusCode::OK => Ok(response
353 .json::<ModelInformation>()
354 .await
355 .map_err(|e| GoogleAPIError {
356 message: format!(
357 "Failed to deserialize API response into v1::gemini::ModelInformation: {}",
358 e
359 ),
360 code: None,
361 })?),
362 _ => Err(self.new_error_from_status_code(response.status())),
363 }
364 }
365 Err(e) => Err(self.new_error_from_reqwest_error(e)),
366 }
367 }
368 pub async fn get_model_list(
372 &self,
373 timeout: u64,
374 ) -> Result<ModelInformationList, GoogleAPIError> {
375 let result = self.get(timeout).await?;
376
377 match result {
378 Ok(response) => {
379 match response.status() {
380 reqwest::StatusCode::OK => Ok(response
381 .json::<ModelInformationList>()
382 .await
383 .map_err(|e| GoogleAPIError {
384 message: format!(
385 "Failed to deserialize API response into Vec<v1::gemini::ModelInformationList>: {}",
386 e
387 ),
388 code: None,
389 })?),
390 _ => Err(self.new_error_from_status_code(response.status())),
391 }
392 }
393 Err(e) => Err(self.new_error_from_reqwest_error(e)),
394 }
395 }
396
397 fn convert_json_value_to_response(
404 json_value: &serde_json::Value,
405 ) -> Result<GeminiResponse, serde_json::error::Error> {
406 serde_json::from_value(json_value.clone())
407 }
408
409 fn get_reqwest_client(&self, timeout: u64) -> Result<reqwest::Client, GoogleAPIError> {
410 let client: reqwest::Client = reqwest::Client::builder()
411 .timeout(Duration::from_secs(timeout))
412 .build()
413 .map_err(|e| self.new_error_from_reqwest_error(e.without_url()))?;
414 Ok(client)
415 }
416 fn new_error_from_status_code(&self, code: reqwest::StatusCode) -> GoogleAPIError {
418 let status_text = code.canonical_reason().unwrap_or("Unknown Status");
419 let message = format!("HTTP Error: {}: {}", code.as_u16(), status_text);
420
421 GoogleAPIError {
422 message,
423 code: Some(code),
424 }
425 }
426
427 fn new_error_from_api_message(&self, code: StatusCode, message: String) -> GoogleAPIError {
429 let message = format!("API message: {message}.");
430
431 GoogleAPIError {
432 message,
433 code: Some(code),
434 }
435 }
436
437 fn new_error_from_reqwest_error(&self, mut e: reqwest::Error) -> GoogleAPIError {
439 if let Some(url) = e.url_mut() {
440 url.query_pairs_mut().clear();
442 }
443
444 GoogleAPIError {
445 message: format!("{}", e),
446 code: e.status(),
447 }
448 }
449}
450
451#[derive(Debug)]
456pub(crate) struct Url {
457 pub url: String,
458}
459impl Url {
460 pub(crate) fn new(model: &Model, api_key: String, response_type: &ResponseType) -> Self {
461 let base_url = PUBLIC_API_URL_BASE.to_owned();
462 match response_type {
463 ResponseType::GenerateContent => Self {
464 url: format!(
465 "{}/models/{}:{}?key={}",
466 base_url, model, response_type, api_key
467 ),
468 },
469 ResponseType::StreamGenerateContent => Self {
470 url: format!(
471 "{}/models/{}:{}?key={}",
472 base_url, model, response_type, api_key
473 ),
474 },
475 ResponseType::GetModel => Self {
476 url: format!("{}/models/{}?key={}", base_url, model, api_key),
477 },
478 ResponseType::GetModelList => Self {
479 url: format!("{}/models?key={}", base_url, api_key),
480 },
481 ResponseType::CountTokens => Self {
482 url: format!(
483 "{}/models/{}:{}?key={}",
484 base_url, model, response_type, api_key
485 ),
486 },
487 _ => panic!("Unsupported response type: {:?}", response_type),
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;

    /// Status-only errors use the canonical reason phrase.
    #[test]
    fn test_new_error_from_status_code() {
        let client = Client::new("my-api-key".to_string());
        let status_code = StatusCode::BAD_REQUEST;

        let error = client.new_error_from_status_code(status_code);

        assert_eq!(error.message, "HTTP Error: 400: Bad Request");
        assert_eq!(error.code, Some(status_code));
    }

    /// Errors built from an API message wrap the message and keep the status.
    #[test]
    fn test_new_error_from_api_message() {
        let client = Client::new("my-api-key".to_string());
        let status_code = StatusCode::TOO_MANY_REQUESTS;

        let error = client.new_error_from_api_message(status_code, "quota exceeded".to_string());

        assert_eq!(error.message, "API message: quota exceeded.");
        assert_eq!(error.code, Some(status_code));
    }

    /// `Client::new` targets `GenerateContent` on the default model.
    /// The Vertex-related fields are left unset.
    #[test]
    fn test_client_new_defaults() {
        let api_key = "my-api-key";
        let client = Client::new(api_key.to_string());
        let expected = Url::new(
            &Model::default(),
            api_key.to_string(),
            &ResponseType::GenerateContent,
        );

        assert_eq!(client.url, expected.url);
        assert!(client.region.is_none());
        assert!(client.project_id.is_none());
    }

    #[test]
    fn test_url_new() {
        let model = Model::default();
        let api_key = String::from("my-api-key");
        let url = Url::new(&model, api_key.clone(), &ResponseType::GenerateContent);

        assert_eq!(
            url.url,
            format!(
                "{}/models/{}:generateContent?key={}",
                PUBLIC_API_URL_BASE, model, api_key
            )
        );
    }

    /// The other model-scoped operations share the `models/{model}:{operation}` shape.
    #[test]
    fn test_url_new_model_operations() {
        let model = Model::default();
        let api_key = "my-api-key";

        for response_type in [ResponseType::StreamGenerateContent, ResponseType::CountTokens].iter()
        {
            let url = Url::new(&model, api_key.to_string(), response_type);
            assert_eq!(
                url.url,
                format!(
                    "{}/models/{}:{}?key={}",
                    PUBLIC_API_URL_BASE, model, response_type, api_key
                )
            );
        }
    }

    /// Model metadata and model listing use plain GET paths without an operation suffix.
    #[test]
    fn test_url_new_model_lookups() {
        let model = Model::default();
        let api_key = "my-api-key";

        let url = Url::new(&model, api_key.to_string(), &ResponseType::GetModel);
        assert_eq!(
            url.url,
            format!("{}/models/{}?key={}", PUBLIC_API_URL_BASE, model, api_key)
        );

        let url = Url::new(&model, api_key.to_string(), &ResponseType::GetModelList);
        assert_eq!(
            url.url,
            format!("{}/models?key={}", PUBLIC_API_URL_BASE, api_key)
        );
    }
}