google_generative_ai_rs/v1/
vertexai.rs

//! Contains logic and types specific to the Vertex AI endpoint (as opposed to the public Gemini API endpoint).
2use std::{fmt, sync::Arc};
3
4use super::{
5    api::{Client, Url},
6    gemini::{Model, ResponseType},
7};
8use crate::v1::errors::GoogleAPIError;
9
/// Base URL template for the Vertex AI endpoint; the `{region}` placeholder is
/// replaced with the caller's region when a URL is built.
const VERTEX_AI_API_URL_BASE: &str = "https://{region}-aiplatform.googleapis.com/v1";

/// Scope passed to the GCP token provider when requesting an authentication token.
const GCP_API_AUTH_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
13
14impl Client {
15    /// Create a new private API client (Vertex AI) using the default model, `Gemini-pro`.
16    ///
17    /// Parameters:
18    /// * region - the GCP region to use
19    /// * project_id - the GCP account project_id to use
20    pub fn new_from_region_project_id(region: String, project_id: String) -> Self {
21        Client::new_from_region_project_id_response_type(
22            region,
23            project_id,
24            ResponseType::StreamGenerateContent,
25        )
26    }
27    pub fn new_from_region_project_id_response_type(
28        region: String,
29        project_id: String,
30        response_type: ResponseType,
31    ) -> Self {
32        let url = Url::new_from_region_project_id(
33            &Model::default(),
34            region.clone(),
35            project_id.clone(),
36            &response_type,
37        );
38        Self {
39            url: url.url,
40            model: Model::default(),
41            region: Some(region),
42            project_id: Some(project_id),
43            response_type,
44        }
45    }
46    /// Create a new private API client.
47    /// Parameters:
48    /// * model - the Gemini model to use
49    /// * region - the GCP region to use
50    /// * project_id - the GCP account project_id to use
51    pub fn new_from_model_region_project_id(
52        model: Model,
53        region: String,
54        project_id: String,
55    ) -> Self {
56        let url = Url::new_from_region_project_id(
57            &model,
58            region.clone(),
59            project_id.clone(),
60            &ResponseType::StreamGenerateContent,
61        );
62        Self {
63            url: url.url,
64            model,
65            region: Some(region),
66            project_id: Some(project_id),
67            response_type: ResponseType::StreamGenerateContent,
68        }
69    }
70
71    /// If this is a Vertex AI request, get the token from the GCP authn library, if it is correctly configured, else None.
72    pub(crate) async fn get_auth_token_option(&self) -> Result<Option<String>, GoogleAPIError> {
73        let token_option = if self.project_id.is_some() && self.region.is_some() {
74            let token = self.get_gcp_authn_token().await?.as_str().to_string();
75            Some(token)
76        } else {
77            None
78        };
79        Ok(token_option)
80    }
81    /// Gets a GCP authn token.
82    async fn get_gcp_authn_token(&self) -> Result<Arc<gcp_auth::Token>, GoogleAPIError> {
83        let provider = gcp_auth::provider().await.map_err(|e| GoogleAPIError {
84            message: format!("Failed to create AuthenticationManager: {}", e),
85            code: None,
86        })?;
87        let scopes = &[GCP_API_AUTH_SCOPE];
88        let token = provider.token(scopes).await.map_err(|e| GoogleAPIError {
89            message: format!("Failed to generate authentication token: {}", e),
90            code: None,
91        })?;
92        Ok(token)
93    }
94}
95/// Ensuring there is no leakage of secrets
96impl fmt::Display for Client {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        if self.region.is_some() && self.project_id.is_some() {
99            write!(
100                f,
101                "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}",
102                self.url, self.model, self.region, self.project_id
103            )
104        } else {
105            write!(
106                f,
107                "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}",
108                Url::new(
109                    &self.model,
110                    "*************".to_string(),
111                    &self.response_type
112                ),
113                self.model,
114                self.region,
115                self.project_id
116            )
117        }
118    }
119}
120
121impl Url {
122    pub(crate) fn new_from_region_project_id(
123        model: &Model,
124        region: String,
125        project_id: String,
126        response_type: &ResponseType,
127    ) -> Self {
128        let base_url = VERTEX_AI_API_URL_BASE
129            .to_owned()
130            .replace("{region}", &region);
131
132        let url = format!(
133            "{}/projects/{}/locations/{}/publishers/google/models/{}:{}",
134            base_url, project_id, region, model, response_type,
135        );
136        Self { url }
137    }
138}
#[cfg(test)]
mod tests {
    use crate::v1::{
        api::{Client, Url},
        gemini::{Model, ResponseType},
    };

    use super::*;

    /// Region shared by every test in this module.
    const REGION: &str = "us-central1";
    /// Project id shared by every test in this module.
    const PROJECT_ID: &str = "my-project";

    /// The streaming Vertex AI URL expected for `model` with the shared region and project.
    fn expected_streaming_url(model: &Model) -> String {
        format!(
            "{}/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent",
            VERTEX_AI_API_URL_BASE.replace("{region}", REGION),
            PROJECT_ID,
            REGION,
            model
        )
    }

    #[test]
    fn test_new_from_region_project_id() {
        let client =
            Client::new_from_region_project_id(REGION.to_string(), PROJECT_ID.to_string());

        assert_eq!(client.region, Some(REGION.to_string()));
        assert_eq!(client.project_id, Some(PROJECT_ID.to_string()));
        assert_eq!(client.model, Model::default());
        assert_eq!(client.url, expected_streaming_url(&Model::default()));
        assert!(matches!(
            client.response_type,
            ResponseType::StreamGenerateContent
        ));
    }

    #[test]
    fn test_new_from_region_project_id_response_type() {
        let client = Client::new_from_region_project_id_response_type(
            REGION.to_string(),
            PROJECT_ID.to_string(),
            ResponseType::StreamGenerateContent,
        );

        assert_eq!(client.region, Some(REGION.to_string()));
        assert_eq!(client.project_id, Some(PROJECT_ID.to_string()));
        assert_eq!(client.model, Model::default());
        assert_eq!(client.url, expected_streaming_url(&Model::default()));
        assert!(matches!(
            client.response_type,
            ResponseType::StreamGenerateContent
        ));
    }

    #[test]
    fn test_new_from_model_region_project_id() {
        let model = Model::default();
        let client = Client::new_from_model_region_project_id(
            model.clone(),
            REGION.to_string(),
            PROJECT_ID.to_string(),
        );

        assert_eq!(client.url, expected_streaming_url(&model));
        assert_eq!(client.model, model);
        assert_eq!(client.region, Some(REGION.to_string()));
        assert_eq!(client.project_id, Some(PROJECT_ID.to_string()));
        assert!(matches!(
            client.response_type,
            ResponseType::StreamGenerateContent
        ));
    }

    #[test]
    fn test_url_new_from_region_project_id() {
        let model = Model::default();
        let url = Url::new_from_region_project_id(
            &model,
            REGION.to_string(),
            PROJECT_ID.to_string(),
            &ResponseType::StreamGenerateContent,
        );

        assert_eq!(url.url, expected_streaming_url(&model));
    }
}