google_generative_ai_rs/v1/
api.rs

1//! Manages the interaction with the REST API for the Gemini API.
2use futures::prelude::*;
3use futures::stream::StreamExt;
4use reqwest::StatusCode;
5use reqwest_streams::error::StreamBodyError;
6use reqwest_streams::*;
7use serde_json;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::Mutex;
12
13use crate::v1::errors::GoogleAPIError;
14use crate::v1::gemini::request::Request;
15use crate::v1::gemini::response::GeminiResponse;
16use crate::v1::gemini::Model;
17
18use super::gemini::response::{GeminiErrorResponse, StreamedGeminiResponse, TokenCount};
19use super::gemini::{ModelInformation, ModelInformationList, ResponseType};
20
/// Root of the public Gemini REST API.
///
/// Building with the `beta` feature targets the `v1beta` surface; otherwise the stable `v1`
/// surface is used.
const PUBLIC_API_URL_BASE: &str = if cfg!(feature = "beta") {
    "https://generativelanguage.googleapis.com/v1beta"
} else {
    "https://generativelanguage.googleapis.com/v1"
};
26
27/// Enables a streamed or non-streamed response to be returned from the API.
28#[derive(Debug)]
29pub enum PostResult {
30    Rest(GeminiResponse),
31    Streamed(StreamedGeminiResponse),
32    Count(TokenCount),
33}
34impl PostResult {
35    pub fn rest(self) -> Option<GeminiResponse> {
36        match self {
37            PostResult::Rest(response) => Some(response),
38            _ => None,
39        }
40    }
41    pub fn streamed(self) -> Option<StreamedGeminiResponse> {
42        match self {
43            PostResult::Streamed(streamed_response) => Some(streamed_response),
44            _ => None,
45        }
46    }
47    pub fn count(self) -> Option<TokenCount> {
48        match self {
49            PostResult::Count(response) => Some(response),
50            _ => None,
51        }
52    }
53}
54
55/// Manages the specific API connection
56pub struct Client {
57    pub url: String,
58    pub model: Model,
59    pub region: Option<String>,
60    pub project_id: Option<String>,
61    pub response_type: ResponseType,
62}
63
64/// Implements the functions for the API client.
65/// TODO: This is getting unwieldy. We need to refactor this into a more manageable state.
66///         See Issue #26 - 'Code tidy and improvement'
67impl Client {
68    /// Creates a default new public API client.
69    pub fn new(api_key: String) -> Self {
70        let url = Url::new(&Model::default(), api_key, &ResponseType::GenerateContent);
71        Self {
72            url: url.url,
73            model: Model::default(),
74            region: None,
75            project_id: None,
76            response_type: ResponseType::GenerateContent,
77        }
78    }
79
80    /// Creates a default new public API client for a specified response type.
81    pub fn new_from_response_type(response_type: ResponseType, api_key: String) -> Self {
82        let url = Url::new(&Model::default(), api_key, &response_type);
83        Self {
84            url: url.url,
85            model: Model::default(),
86            region: None,
87            project_id: None,
88            response_type,
89        }
90    }
91
92    /// Create a new public API client for a specified model.
93    pub fn new_from_model(model: Model, api_key: String) -> Self {
94        let url = Url::new(&model, api_key, &ResponseType::GenerateContent);
95        Self {
96            url: url.url,
97            model,
98            region: None,
99            project_id: None,
100            response_type: ResponseType::GenerateContent,
101        }
102    }
103
104    /// Create a new public API client for a specified model.
105    pub fn new_from_model_response_type(
106        model: Model,
107        api_key: String,
108        response_type: ResponseType,
109    ) -> Self {
110        let url = Url::new(&model, api_key, &response_type);
111        Self {
112            url: url.url,
113            model,
114            region: None,
115            project_id: None,
116            response_type,
117        }
118    }
119
120    // post
121    pub async fn post(
122        &self,
123        timeout: u64,
124        api_request: &Request,
125    ) -> Result<PostResult, GoogleAPIError> {
126        let client: reqwest::Client = self.get_reqwest_client(timeout)?;
127        match self.response_type {
128            ResponseType::GenerateContent => {
129                let result = self.get_post_result(client, api_request).await?;
130                Ok(PostResult::Rest(result))
131            }
132            ResponseType::StreamGenerateContent => {
133                let result = self.get_streamed_post_result(client, api_request).await?;
134                Ok(PostResult::Streamed(result))
135            }
136            ResponseType::CountTokens => {
137                let result = self.get_token_count(client, api_request).await?;
138                Ok(PostResult::Count(result))
139            }
140            _ => Err(GoogleAPIError {
141                message: format!("Unsupported response type: {:?}", self.response_type),
142                code: None,
143            }),
144        }
145    }
146
147    /// A standard post request, i.e., not streamed
148    async fn get_post_result(
149        &self,
150        client: reqwest::Client,
151        api_request: &Request,
152    ) -> Result<GeminiResponse, GoogleAPIError> {
153        let token_option = self.get_auth_token_option().await?;
154
155        let result = self
156            .get_post_response(client, api_request, token_option)
157            .await;
158
159        if let Ok(result) = result {
160            match result.status() {
161                reqwest::StatusCode::OK => {
162                    Ok(result.json::<GeminiResponse>().await.map_err(|e|GoogleAPIError {
163                    message: format!(
164                            "Failed to deserialize API response into v1::gemini::response::GeminiResponse: {}",
165                            e
166                        ),
167                    code: None,
168                    })?)
169                },
170                _ => {
171                    let status = result.status();
172
173                    match result.json::<GeminiErrorResponse>().await {
174                        Ok(GeminiErrorResponse::Error { message, .. }) => Err(self.new_error_from_api_message(status, message)),
175                        Err(_) => Err(self.new_error_from_status_code(status)),
176                    }
177                },
178            }
179        } else {
180            Err(self.new_error_from_reqwest_error(result.unwrap_err()))
181        }
182    }
183
184    // Define the function that accepts the stream and the consumer
185    /// A streamed post request
186    async fn get_streamed_post_result(
187        &self,
188        client: reqwest::Client,
189        api_request: &Request,
190    ) -> Result<StreamedGeminiResponse, GoogleAPIError> {
191        let token_option = self.get_auth_token_option().await?;
192
193        let result = self
194            .get_post_response(client, api_request, token_option)
195            .await;
196
197        match result {
198            Ok(response) => match response.status() {
199                reqwest::StatusCode::OK => {
200                    // Wire to enable introspection on the response stream
201                    let json_stream = response.json_array_stream::<serde_json::Value>(2048); //TODO what is a good length?;
202
203                    Ok(StreamedGeminiResponse {
204                        response_stream: Some(json_stream),
205                    })
206                }
207                _ => Err(self.new_error_from_status_code(response.status())),
208            },
209            Err(e) => Err(self.new_error_from_reqwest_error(e)),
210        }
211    }
212
213    /// Applies an asynchronous operation to each item in a stream, potentially concurrently.
214    ///
215    /// This function retrieves each item from the provided stream, processes it using the given
216    /// consumer callback, and awaits the futures produced by the consumer. The concurrency level
217    /// is unbounded, meaning items will be processed as soon as they are ready without a limit.
218    ///
219    /// # Type Parameters
220    ///
221    /// - `F`: The type of the consumer closure. It must accept a `GeminiResponse` and return a future.
222    /// - `Fut`: The future type returned by the `consumer` closure. It must resolve to `()`.
223    ///
224    /// # Parameters
225    ///
226    /// - `stream`: A `Pin<Box<dyn Stream>>` that produces items of type `Result<serde_json::Value, StreamBodyError>`.
227    ///   The stream already needs to be pinned and boxed when passed into this function.
228    /// - `consumer`: A mutable closure that is called for each `GeminiResponse`. The results of the
229    ///   closure are futures which will be awaited to completion. This closure needs to be `Send` and
230    ///   `'static` to allow for concurrent and potentially multi-threaded execution.
231    pub async fn for_each_async<F, Fut>(
232        stream: Pin<Box<dyn Stream<Item = Result<serde_json::Value, StreamBodyError>> + Send>>,
233        consumer: F,
234    ) where
235        F: FnMut(GeminiResponse) -> Fut + Send + 'static,
236        Fut: Future<Output = ()>,
237    {
238        // Since the stream is already boxed and pinned, you can directly use it
239        let consumer = Arc::new(Mutex::new(consumer));
240
241        // Use the for_each_concurrent method to apply the consumer to each item
242        // in the stream, handling each item as it's ready. Set `None` for unbounded concurrency,
243        // or set a limit with `Some(n)`
244
245        stream
246            .for_each_concurrent(None, |item: Result<serde_json::Value, StreamBodyError>| {
247                let consumer = Arc::clone(&consumer);
248                async move {
249                    let res = match item {
250                        Ok(result) => {
251                            Client::convert_json_value_to_response(&result).map_err(|e| {
252                                GoogleAPIError {
253                                    message: format!(
254                                        "Failed to get JSON stream from request: {}",
255                                        e
256                                    ),
257                                    code: None,
258                                }
259                            })
260                        }
261                        Err(e) => Err(GoogleAPIError {
262                            message: format!("Failed to get JSON stream from request: {}", e),
263                            code: None,
264                        }),
265                    };
266
267                    if let Ok(response) = res {
268                        let mut consumer = consumer.lock().await;
269                        consumer(response).await;
270                    }
271                }
272            })
273            .await;
274    }
275
276    /// Gets a ['reqwest::GeminiResponse'] from a post request.
277    /// Parameters:
278    /// * client - the ['reqwest::Client'] to use
279    /// * api_request - the ['Request'] to send
280    /// * authn_token - an optional authn token to use
281    async fn get_post_response(
282        &self,
283        client: reqwest::Client,
284        api_request: &Request,
285        authn_token: Option<String>,
286    ) -> Result<reqwest::Response, reqwest::Error> {
287        let mut request_builder = client
288            .post(&self.url)
289            .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME"))
290            .header(reqwest::header::CONTENT_TYPE, "application/json");
291
292        // If a GCP authn token is provided, use it
293        if let Some(token) = authn_token {
294            request_builder = request_builder.bearer_auth(token);
295        }
296
297        request_builder.json(&api_request).send().await
298    }
299    // Count Tokens - see: "https://ai.google.dev/tutorials/rest_quickstart#count_tokens"
300    //
301    /// Parameters:
302    /// * timeout - the timeout in seconds
303    /// * api_request - the request to send to check token count
304    pub async fn get_token_count(
305        &self,
306        client: reqwest::Client,
307        api_request: &Request,
308    ) -> Result<TokenCount, GoogleAPIError> {
309        let token_option = self.get_auth_token_option().await?;
310
311        let result = self
312            .get_post_response(client, api_request, token_option)
313            .await;
314
315        match result {
316            Ok(response) => match response.status() {
317                reqwest::StatusCode::OK => Ok(response.json::<TokenCount>().await.map_err(|e|GoogleAPIError {
318                message: format!(
319                        "Failed to deserialize API response into v1::gemini::response::TokenCount: {}",
320                        e
321                    ),
322                code: None,
323            })?),
324                _ => Err(self.new_error_from_status_code(response.status())),
325            },
326            Err(e) => Err(self.new_error_from_reqwest_error(e)),
327        }
328    }
329
330    /// Get for the url specified in 'self'
331    async fn get(
332        &self,
333        timeout: u64,
334    ) -> Result<Result<reqwest::Response, reqwest::Error>, GoogleAPIError> {
335        let client: reqwest::Client = self.get_reqwest_client(timeout)?;
336        let result = client
337            .get(&self.url)
338            .header(reqwest::header::USER_AGENT, env!("CARGO_CRATE_NAME"))
339            .send()
340            .await;
341        Ok(result)
342    }
343    /// Gets a model - see: "https://ai.google.dev/tutorials/rest_quickstart#get_model"
344    /// Parameters:
345    /// * timeout - the timeout in seconds
346    pub async fn get_model(&self, timeout: u64) -> Result<ModelInformation, GoogleAPIError> {
347        let result = self.get(timeout).await?;
348
349        match result {
350            Ok(response) => {
351                match response.status() {
352                    reqwest::StatusCode::OK => Ok(response
353                        .json::<ModelInformation>()
354                        .await
355                        .map_err(|e| GoogleAPIError {
356                            message: format!(
357                        "Failed to deserialize API response into v1::gemini::ModelInformation: {}",
358                        e
359                    ),
360                            code: None,
361                        })?),
362                    _ => Err(self.new_error_from_status_code(response.status())),
363                }
364            }
365            Err(e) => Err(self.new_error_from_reqwest_error(e)),
366        }
367    }
368    /// Gets a list of models - see: "https://ai.google.dev/tutorials/rest_quickstart#list_models"
369    /// Parameters:
370    /// * timeout - the timeout in seconds
371    pub async fn get_model_list(
372        &self,
373        timeout: u64,
374    ) -> Result<ModelInformationList, GoogleAPIError> {
375        let result = self.get(timeout).await?;
376
377        match result {
378            Ok(response) => {
379                match response.status() {
380                    reqwest::StatusCode::OK => Ok(response
381                        .json::<ModelInformationList>()
382                        .await
383                        .map_err(|e| GoogleAPIError {
384                            message: format!(
385                        "Failed to deserialize API response into Vec<v1::gemini::ModelInformationList>: {}",
386                        e
387                    ),
388                        code: None,
389                    })?),
390                    _ => Err(self.new_error_from_status_code(response.status())),
391                }
392            }
393            Err(e) => Err(self.new_error_from_reqwest_error(e)),
394        }
395    }
396
397    // TODO function - see "https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/function-calling"
398
399    // TODO embedContent - see: "https://ai.google.dev/tutorials/rest_quickstart#embedding"
400
401    /// The current version of the Vertex API only supports streamed responses, so
402    /// in order to handle any issues we use a serde_json::Value and then convert to a Gemini [`Candidate`].
403    fn convert_json_value_to_response(
404        json_value: &serde_json::Value,
405    ) -> Result<GeminiResponse, serde_json::error::Error> {
406        serde_json::from_value(json_value.clone())
407    }
408
409    fn get_reqwest_client(&self, timeout: u64) -> Result<reqwest::Client, GoogleAPIError> {
410        let client: reqwest::Client = reqwest::Client::builder()
411            .timeout(Duration::from_secs(timeout))
412            .build()
413            .map_err(|e| self.new_error_from_reqwest_error(e.without_url()))?;
414        Ok(client)
415    }
416    /// Creates a new error from a status code.
417    fn new_error_from_status_code(&self, code: reqwest::StatusCode) -> GoogleAPIError {
418        let status_text = code.canonical_reason().unwrap_or("Unknown Status");
419        let message = format!("HTTP Error: {}: {}", code.as_u16(), status_text);
420
421        GoogleAPIError {
422            message,
423            code: Some(code),
424        }
425    }
426
427    /// Creates a new error from a status code.
428    fn new_error_from_api_message(&self, code: StatusCode, message: String) -> GoogleAPIError {
429        let message = format!("API message: {message}.");
430
431        GoogleAPIError {
432            message,
433            code: Some(code),
434        }
435    }
436
437    /// Creates a new error from a reqwest error.
438    fn new_error_from_reqwest_error(&self, mut e: reqwest::Error) -> GoogleAPIError {
439        if let Some(url) = e.url_mut() {
440            // Remove the API key from the URL, if any
441            url.query_pairs_mut().clear();
442        }
443
444        GoogleAPIError {
445            message: format!("{}", e),
446            code: e.status(),
447        }
448    }
449}
450
451/// There are two different URLs for the API, depending on whether the model is public or private.
452/// Authn for public models is via an API key, while authn for private models is via application default credentials (ADC).
453/// The public API URL is in the form of: https://generativelanguage.googleapis.com/v1/models/{model}:{generateContent|streamGenerateContent}
454/// The Vertex AI API URL is in the form of: https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model}:{streamGenerateContent}
455#[derive(Debug)]
456pub(crate) struct Url {
457    pub url: String,
458}
459impl Url {
460    pub(crate) fn new(model: &Model, api_key: String, response_type: &ResponseType) -> Self {
461        let base_url = PUBLIC_API_URL_BASE.to_owned();
462        match response_type {
463            ResponseType::GenerateContent => Self {
464                url: format!(
465                    "{}/models/{}:{}?key={}",
466                    base_url, model, response_type, api_key
467                ),
468            },
469            ResponseType::StreamGenerateContent => Self {
470                url: format!(
471                    "{}/models/{}:{}?key={}",
472                    base_url, model, response_type, api_key
473                ),
474            },
475            ResponseType::GetModel => Self {
476                url: format!("{}/models/{}?key={}", base_url, model, api_key),
477            },
478            ResponseType::GetModelList => Self {
479                url: format!("{}/models?key={}", base_url, api_key),
480            },
481            ResponseType::CountTokens => Self {
482                url: format!(
483                    "{}/models/{}:{}?key={}",
484                    base_url, model, response_type, api_key
485                ),
486            },
487            _ => panic!("Unsupported response type: {:?}", response_type),
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    // `StatusCode`, `Model`, `ResponseType`, `Url` and `PUBLIC_API_URL_BASE` all come in
    // through the parent module's imports/items.
    use super::*;

    #[test]
    fn test_new_error_from_status_code() {
        let client = Client::new("my-api-key".to_string());
        let status_code = StatusCode::BAD_REQUEST;

        let error = client.new_error_from_status_code(status_code);

        assert_eq!(error.message, "HTTP Error: 400: Bad Request");
        assert_eq!(error.code, Some(status_code));
    }

    #[test]
    fn test_new_error_from_api_message() {
        let client = Client::new("my-api-key".to_string());
        let status_code = StatusCode::BAD_REQUEST;

        let error = client.new_error_from_api_message(status_code, "API key not valid".to_string());

        assert_eq!(error.message, "API message: API key not valid.");
        assert_eq!(error.code, Some(status_code));
    }

    #[test]
    fn test_url_new() {
        let model = Model::default();
        let api_key = String::from("my-api-key");
        let url = Url::new(&model, api_key.clone(), &ResponseType::GenerateContent);

        assert_eq!(
            url.url,
            format!(
                "{}/models/{}:generateContent?key={}",
                PUBLIC_API_URL_BASE, model, api_key
            )
        );
    }

    #[test]
    fn test_url_new_model_actions() {
        let model = Model::default();
        let api_key = String::from("my-api-key");

        let streamed = ResponseType::StreamGenerateContent;
        assert_eq!(
            Url::new(&model, api_key.clone(), &streamed).url,
            format!(
                "{}/models/{}:{}?key={}",
                PUBLIC_API_URL_BASE, model, streamed, api_key
            )
        );

        let count = ResponseType::CountTokens;
        assert_eq!(
            Url::new(&model, api_key.clone(), &count).url,
            format!(
                "{}/models/{}:{}?key={}",
                PUBLIC_API_URL_BASE, model, count, api_key
            )
        );
    }

    #[test]
    fn test_url_new_model_metadata() {
        let model = Model::default();
        let api_key = String::from("my-api-key");

        assert_eq!(
            Url::new(&model, api_key.clone(), &ResponseType::GetModel).url,
            format!("{}/models/{}?key={}", PUBLIC_API_URL_BASE, model, api_key)
        );
        assert_eq!(
            Url::new(&model, api_key.clone(), &ResponseType::GetModelList).url,
            format!("{}/models?key={}", PUBLIC_API_URL_BASE, api_key)
        );
    }
}