google_generative_ai_rs/v1/
vertexai.rs1use std::{fmt, sync::Arc};
3
4use super::{
5 api::{Client, Url},
6 gemini::{Model, ResponseType},
7};
8use crate::v1::errors::GoogleAPIError;
9
/// Base URL template for the Vertex AI REST API. The `{region}` placeholder is
/// substituted with the caller-supplied region in `Url::new_from_region_project_id`.
10const VERTEX_AI_API_URL_BASE: &str = "https://{region}-aiplatform.googleapis.com/v1";
11
/// Scope requested from the `gcp_auth` token provider in `Client::get_gcp_authn_token`.
12const GCP_API_AUTH_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
13
14impl Client {
15 pub fn new_from_region_project_id(region: String, project_id: String) -> Self {
21 Client::new_from_region_project_id_response_type(
22 region,
23 project_id,
24 ResponseType::StreamGenerateContent,
25 )
26 }
27 pub fn new_from_region_project_id_response_type(
28 region: String,
29 project_id: String,
30 response_type: ResponseType,
31 ) -> Self {
32 let url = Url::new_from_region_project_id(
33 &Model::default(),
34 region.clone(),
35 project_id.clone(),
36 &response_type,
37 );
38 Self {
39 url: url.url,
40 model: Model::default(),
41 region: Some(region),
42 project_id: Some(project_id),
43 response_type,
44 }
45 }
46 pub fn new_from_model_region_project_id(
52 model: Model,
53 region: String,
54 project_id: String,
55 ) -> Self {
56 let url = Url::new_from_region_project_id(
57 &model,
58 region.clone(),
59 project_id.clone(),
60 &ResponseType::StreamGenerateContent,
61 );
62 Self {
63 url: url.url,
64 model,
65 region: Some(region),
66 project_id: Some(project_id),
67 response_type: ResponseType::StreamGenerateContent,
68 }
69 }
70
71 pub(crate) async fn get_auth_token_option(&self) -> Result<Option<String>, GoogleAPIError> {
73 let token_option = if self.project_id.is_some() && self.region.is_some() {
74 let token = self.get_gcp_authn_token().await?.as_str().to_string();
75 Some(token)
76 } else {
77 None
78 };
79 Ok(token_option)
80 }
81 async fn get_gcp_authn_token(&self) -> Result<Arc<gcp_auth::Token>, GoogleAPIError> {
83 let provider = gcp_auth::provider().await.map_err(|e| GoogleAPIError {
84 message: format!("Failed to create AuthenticationManager: {}", e),
85 code: None,
86 })?;
87 let scopes = &[GCP_API_AUTH_SCOPE];
88 let token = provider.token(scopes).await.map_err(|e| GoogleAPIError {
89 message: format!("Failed to generate authentication token: {}", e),
90 code: None,
91 })?;
92 Ok(token)
93 }
94}
95impl fmt::Display for Client {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 if self.region.is_some() && self.project_id.is_some() {
99 write!(
100 f,
101 "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}",
102 self.url, self.model, self.region, self.project_id
103 )
104 } else {
105 write!(
106 f,
107 "GenerativeAiClient {{ url: {:?}, model: {:?}, region: {:?}, project_id: {:?} }}",
108 Url::new(
109 &self.model,
110 "*************".to_string(),
111 &self.response_type
112 ),
113 self.model,
114 self.region,
115 self.project_id
116 )
117 }
118 }
119}
120
121impl Url {
122 pub(crate) fn new_from_region_project_id(
123 model: &Model,
124 region: String,
125 project_id: String,
126 response_type: &ResponseType,
127 ) -> Self {
128 let base_url = VERTEX_AI_API_URL_BASE
129 .to_owned()
130 .replace("{region}", ®ion);
131
132 let url = format!(
133 "{}/projects/{}/locations/{}/publishers/google/models/{}:{}",
134 base_url, project_id, region, model, response_type,
135 );
136 Self { url }
137 }
138}
#[cfg(test)]
mod tests {
    //! Unit tests for the Vertex AI constructors and URL builder in this module.

    use crate::v1::{
        api::{Client, Url},
        gemini::{Model, ResponseType},
    };

    use super::*;

    /// The region/project constructor stores both values and the matching
    /// streaming endpoint URL for the default model.
    #[test]
    fn test_new_from_region_project_id() {
        let region = String::from("us-central1");
        let project_id = String::from("my-project");
        let client = Client::new_from_region_project_id(region.clone(), project_id.clone());

        let expected_url = Url::new_from_region_project_id(
            &Model::default(),
            region.clone(),
            project_id.clone(),
            &ResponseType::StreamGenerateContent,
        )
        .url;

        assert_eq!(client.url, expected_url);
        assert_eq!(client.region, Some(region));
        assert_eq!(client.project_id, Some(project_id));
    }

    /// The model/region/project constructor stores all three values and the
    /// matching streaming endpoint URL.
    #[test]
    fn test_new_from_model_region_project_id() {
        let model = Model::default();
        let region = String::from("us-central1");
        let project_id = String::from("my-project");
        let client = Client::new_from_model_region_project_id(
            model.clone(),
            region.clone(),
            project_id.clone(),
        );

        let expected_url = Url::new_from_region_project_id(
            &model,
            region.clone(),
            project_id.clone(),
            &ResponseType::StreamGenerateContent,
        )
        .url;

        assert_eq!(client.url, expected_url);
        assert_eq!(client.model, model);
        assert_eq!(client.region, Some(region));
        assert_eq!(client.project_id, Some(project_id));
    }

    /// The URL builder produces the documented Vertex AI path layout.
    #[test]
    fn test_url_new_from_region_project_id() {
        let model = Model::default();
        let region = String::from("us-central1");
        let project_id = String::from("my-project");
        let url = Url::new_from_region_project_id(
            &model,
            region.clone(),
            project_id.clone(),
            &ResponseType::StreamGenerateContent,
        );

        assert_eq!(
            url.url,
            format!(
                "{}/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent",
                VERTEX_AI_API_URL_BASE.replace("{region}", &region),
                project_id,
                region,
                model
            )
        );
    }
}