Skip to main content

rig_vertexai/
client.rs

1use crate::completion::CompletionModel;
2use google_cloud_aiplatform_v1 as vertexai;
3use google_cloud_auth::credentials;
4use google_cloud_auth::credentials::Credentials;
5use rig::client::{CompletionClient, Nothing};
6use rig::prelude::*;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::sync::OnceCell;
10
// Environment variable names and terminology (`location`, `project`) were chosen to match the
// Google Gen AI Python client:
// https://googleapis.github.io/python-genai/genai.html#genai.client.Client
13
/// Location used for Vertex AI Gemini models when none is configured.
///
/// Gemini models are best served from the `global` endpoint, which offers higher availability
/// and fewer resource exhaustion errors. Regional endpoints such as `us-central1` or
/// `europe-west4` work as well; select one with `ClientBuilder::with_location()` when data
/// residency requirements or regional quotas make that preferable.
pub const DEFAULT_LOCATION: &str = "global";
21
22#[derive(Clone, Debug, Error)]
23pub enum VertexAiClientError {
24    #[error(
25        "Google Cloud project is required. Set it via `ClientBuilder::with_project()` or `GOOGLE_CLOUD_PROJECT`"
26    )]
27    MissingProject,
28    #[error("failed to build source credentials: {0}")]
29    SourceCredentials(String),
30    #[error("failed to build impersonated credentials: {0}")]
31    ImpersonatedCredentials(String),
32    #[error("failed to build Vertex AI prediction service: {0}")]
33    PredictionService(String),
34    #[error(
35        "Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials or `Client::builder().with_credentials(...).build()` for explicit credentials."
36    )]
37    InvalidInput,
38}
39
40/// Helper function to build credentials with optional service account impersonation.
41fn build_credentials(
42    explicit_creds: Option<Credentials>,
43) -> Result<Credentials, VertexAiClientError> {
44    if let Some(creds) = explicit_creds {
45        Ok(creds)
46    } else {
47        // Build default credentials
48        let source_credentials = credentials::Builder::default()
49            .build()
50            .map_err(|e| VertexAiClientError::SourceCredentials(e.to_string()))?;
51
52        // Check for service account impersonation
53        if let Ok(service_account) = std::env::var("GOOGLE_CLOUD_SERVICE_ACCOUNT") {
54            credentials::impersonated::Builder::from_source_credentials(source_credentials)
55                .with_target_principal(service_account)
56                .build()
57                .map_err(|e| VertexAiClientError::ImpersonatedCredentials(e.to_string()))
58        } else {
59            Ok(source_credentials)
60        }
61    }
62}
63
64#[derive(Clone, Debug)]
65pub struct ClientBuilder {
66    project: Option<String>,
67    location: Option<String>,
68    credentials: Option<Credentials>,
69}
70
71impl ClientBuilder {
72    pub fn new() -> Self {
73        Self {
74            project: None,
75            location: None,
76            credentials: None,
77        }
78    }
79
80    /// Set the Google Cloud project ID explicitly.
81    ///
82    /// If not set, will fall back to `GOOGLE_CLOUD_PROJECT` environment variable.
83    pub fn with_project(mut self, project: &str) -> Self {
84        self.project = Some(project.to_string());
85        self
86    }
87
88    /// Set the Google Cloud location explicitly.
89    ///
90    /// If not set, will fall back to `GOOGLE_CLOUD_LOCATION` environment variable,
91    /// or default to "global" if the env var is also not set.
92    pub fn with_location(mut self, location: &str) -> Self {
93        self.location = Some(location.to_string());
94        self
95    }
96
97    /// Set credentials explicitly.
98    ///
99    /// If not set, will build credentials from Application Default Credentials (ADC),
100    /// with optional service account impersonation if `GOOGLE_CLOUD_SERVICE_ACCOUNT` is set.
101    pub fn with_credentials(mut self, credentials: Credentials) -> Self {
102        self.credentials = Some(credentials);
103        self
104    }
105
106    /// Build the client with the configured values, falling back to environment variables where not set.
107    ///
108    /// The Vertex AI client is built lazily on first use via `get_inner()`.
109    pub fn build(self) -> Result<Client, VertexAiClientError> {
110        let project = self
111            .project
112            .or_else(|| std::env::var("GOOGLE_CLOUD_PROJECT").ok())
113            .ok_or(VertexAiClientError::MissingProject)?;
114
115        let location = self
116            .location
117            .or_else(|| std::env::var("GOOGLE_CLOUD_LOCATION").ok())
118            .unwrap_or_else(|| DEFAULT_LOCATION.to_string());
119
120        let credentials = build_credentials(self.credentials)?;
121
122        Ok(Client {
123            project,
124            location,
125            credentials,
126            vertex_client: Arc::new(OnceCell::new()),
127        })
128    }
129}
130
131impl Default for ClientBuilder {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137#[derive(Clone, Debug)]
138pub struct Client {
139    project: String,
140    location: String,
141    credentials: Credentials,
142    pub(crate) vertex_client:
143        Arc<OnceCell<Result<vertexai::client::PredictionService, VertexAiClientError>>>,
144}
145
146impl Client {
147    /// Create a new client builder that uses environment variables as defaults.
148    ///
149    /// You can override any values using the builder methods:
150    /// - `.with_project()` - override project
151    /// - `.with_location()` - override location
152    /// - `.with_credentials()` - override credentials
153    ///
154    /// Example:
155    /// ```no_run
156    /// # use rig_vertexai::Client;
157    /// # fn example() -> Result<(), rig_vertexai::client::VertexAiClientError> {
158    /// // Use all env vars
159    /// let client = Client::builder().build()?;
160    ///
161    /// // Override just the location
162    /// let client = Client::builder().with_location("us-central1").build()?;
163    ///
164    /// // Override project and location
165    /// let client = Client::builder()
166    ///     .with_project("my-project")
167    ///     .with_location("us-central1")
168    ///     .build()?;
169    /// # Ok(())
170    /// # }
171    /// ```
172    pub fn builder() -> ClientBuilder {
173        ClientBuilder::new()
174    }
175
176    /// Create a new client using environment variables for project, location, and credentials.
177    ///
178    /// Reads from:
179    /// - `GOOGLE_CLOUD_PROJECT` (required)
180    /// - `GOOGLE_CLOUD_LOCATION` (optional, defaults to "global")
181    /// - `GOOGLE_CLOUD_SERVICE_ACCOUNT` (optional, for service account impersonation)
182    ///
183    pub fn new() -> Result<Self, VertexAiClientError> {
184        ClientBuilder::new().build()
185    }
186
187    /// Create a client using environment variables for project, location, and credentials.
188    ///
189    /// This is a convenience method that calls the `ProviderClient::from_env()` trait method.
190    /// Reads from:
191    /// - `GOOGLE_CLOUD_PROJECT` (required)
192    /// - `GOOGLE_CLOUD_LOCATION` (optional, defaults to "global")
193    /// - `GOOGLE_CLOUD_SERVICE_ACCOUNT` (optional, for service account impersonation)
194    pub fn from_env() -> Result<Self, VertexAiClientError> {
195        <Self as ProviderClient>::from_env()
196    }
197
198    pub fn project(&self) -> &str {
199        &self.project
200    }
201
202    pub fn location(&self) -> &str {
203        &self.location
204    }
205
206    pub async fn get_inner(
207        &self,
208    ) -> Result<&vertexai::client::PredictionService, VertexAiClientError> {
209        let credentials = self.credentials.clone();
210        self.vertex_client
211            .get_or_init(|| async {
212                let mut builder = vertexai::client::PredictionService::builder();
213                builder = builder.with_credentials(credentials);
214                builder
215                    .build()
216                    .await
217                    .map_err(|error| VertexAiClientError::PredictionService(error.to_string()))
218            })
219            .await
220            .as_ref()
221            .map_err(Clone::clone)
222    }
223}
224
225impl ProviderClient for Client {
226    type Input = Nothing;
227    type Error = VertexAiClientError;
228
229    fn from_env() -> Result<Self, Self::Error>
230    where
231        Self: Sized,
232    {
233        Client::new()
234    }
235
236    fn from_val(_: Self::Input) -> Result<Self, Self::Error>
237    where
238        Self: Sized,
239    {
240        Err(VertexAiClientError::InvalidInput)
241    }
242}
243
244impl CompletionClient for Client {
245    type CompletionModel = CompletionModel;
246
247    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
248        CompletionModel::new(self.clone(), model.into())
249    }
250}
251
252impl VerifyClient for Client {
253    async fn verify(&self) -> Result<(), VerifyError> {
254        // No API endpoint to verify credentials - they're validated on first use
255        Ok(())
256    }
257}