1use crate::completion::CompletionModel;
2use google_cloud_aiplatform_v1 as vertexai;
3use google_cloud_auth::credentials;
4use google_cloud_auth::credentials::Credentials;
5use rig::client::{CompletionClient, Nothing};
6use rig::prelude::*;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::OnceCell;
10
/// Vertex AI location used when neither `ClientBuilder::with_location()` nor the
/// `GOOGLE_CLOUD_LOCATION` environment variable supplies one.
pub const DEFAULT_LOCATION: &str = "global";
21
22#[derive(Clone, Debug, Error)]
23pub enum VertexAiClientError {
24 #[error(
25 "Google Cloud project is required. Set it via `ClientBuilder::with_project()` or `GOOGLE_CLOUD_PROJECT`"
26 )]
27 MissingProject,
28 #[error("failed to build source credentials: {0}")]
29 SourceCredentials(String),
30 #[error("failed to build impersonated credentials: {0}")]
31 ImpersonatedCredentials(String),
32 #[error("failed to build Vertex AI prediction service: {0}")]
33 PredictionService(String),
34 #[error(
35 "Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials or `Client::builder().with_credentials(...).build()` for explicit credentials."
36 )]
37 InvalidInput,
38}
39
40fn build_credentials(
42 explicit_creds: Option<Credentials>,
43) -> Result<Credentials, VertexAiClientError> {
44 if let Some(creds) = explicit_creds {
45 Ok(creds)
46 } else {
47 let source_credentials = credentials::Builder::default()
49 .build()
50 .map_err(|e| VertexAiClientError::SourceCredentials(e.to_string()))?;
51
52 if let Ok(service_account) = std::env::var("GOOGLE_CLOUD_SERVICE_ACCOUNT") {
54 credentials::impersonated::Builder::from_source_credentials(source_credentials)
55 .with_target_principal(service_account)
56 .build()
57 .map_err(|e| VertexAiClientError::ImpersonatedCredentials(e.to_string()))
58 } else {
59 Ok(source_credentials)
60 }
61 }
62}
63
64#[derive(Clone, Debug)]
65pub struct ClientBuilder {
66 project: Option<String>,
67 location: Option<String>,
68 credentials: Option<Credentials>,
69}
70
71impl ClientBuilder {
72 pub fn new() -> Self {
73 Self {
74 project: None,
75 location: None,
76 credentials: None,
77 }
78 }
79
80 pub fn with_project(mut self, project: &str) -> Self {
84 self.project = Some(project.to_string());
85 self
86 }
87
88 pub fn with_location(mut self, location: &str) -> Self {
93 self.location = Some(location.to_string());
94 self
95 }
96
97 pub fn with_credentials(mut self, credentials: Credentials) -> Self {
102 self.credentials = Some(credentials);
103 self
104 }
105
106 pub fn build(self) -> Result<Client, VertexAiClientError> {
110 let project = self
111 .project
112 .or_else(|| std::env::var("GOOGLE_CLOUD_PROJECT").ok())
113 .ok_or(VertexAiClientError::MissingProject)?;
114
115 let location = self
116 .location
117 .or_else(|| std::env::var("GOOGLE_CLOUD_LOCATION").ok())
118 .unwrap_or_else(|| DEFAULT_LOCATION.to_string());
119
120 let credentials = build_credentials(self.credentials)?;
121
122 Ok(Client {
123 project,
124 location,
125 credentials,
126 vertex_client: Arc::new(OnceCell::new()),
127 })
128 }
129}
130
131impl Default for ClientBuilder {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137#[derive(Clone, Debug)]
138pub struct Client {
139 project: String,
140 location: String,
141 credentials: Credentials,
142 pub(crate) vertex_client:
143 Arc<OnceCell<Result<vertexai::client::PredictionService, VertexAiClientError>>>,
144}
145
146impl Client {
147 pub fn builder() -> ClientBuilder {
173 ClientBuilder::new()
174 }
175
176 pub fn new() -> Result<Self, VertexAiClientError> {
184 ClientBuilder::new().build()
185 }
186
187 pub fn from_env() -> Result<Self, VertexAiClientError> {
195 <Self as ProviderClient>::from_env()
196 }
197
198 pub fn project(&self) -> &str {
199 &self.project
200 }
201
202 pub fn location(&self) -> &str {
203 &self.location
204 }
205
206 pub async fn get_inner(
207 &self,
208 ) -> Result<&vertexai::client::PredictionService, VertexAiClientError> {
209 let credentials = self.credentials.clone();
210 self.vertex_client
211 .get_or_init(|| async {
212 let mut builder = vertexai::client::PredictionService::builder();
213 builder = builder.with_credentials(credentials);
214 builder
215 .build()
216 .await
217 .map_err(|error| VertexAiClientError::PredictionService(error.to_string()))
218 })
219 .await
220 .as_ref()
221 .map_err(Clone::clone)
222 }
223}
224
225impl ProviderClient for Client {
226 type Input = Nothing;
227 type Error = VertexAiClientError;
228
229 fn from_env() -> Result<Self, Self::Error>
230 where
231 Self: Sized,
232 {
233 Client::new()
234 }
235
236 fn from_val(_: Self::Input) -> Result<Self, Self::Error>
237 where
238 Self: Sized,
239 {
240 Err(VertexAiClientError::InvalidInput)
241 }
242}
243
244impl CompletionClient for Client {
245 type CompletionModel = CompletionModel;
246
247 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
248 CompletionModel::new(self.clone(), model.into())
249 }
250}
251
252impl VerifyClient for Client {
253 async fn verify(&self) -> Result<(), VerifyError> {
254 Ok(())
256 }
257}