use crate::completion::CompletionModel;
use google_cloud_aiplatform_v1 as vertexai;
use google_cloud_auth::credentials;
use google_cloud_auth::credentials::Credentials;
use rig::client::{CompletionClient, Nothing};
use rig::prelude::*;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::OnceCell;
/// Location used by `ClientBuilder::build` when neither the builder nor the
/// `GOOGLE_CLOUD_LOCATION` environment variable supplies one.
pub const DEFAULT_LOCATION: &str = "global";
/// Errors produced while configuring the Vertex AI `Client` or lazily building
/// its prediction service.
///
/// Underlying library errors are stored as their `to_string()` rendering rather
/// than as source errors. The enum derives `Clone`; `Client::get_inner` clones a
/// cached error each time it hands one back.
#[derive(Clone, Debug, Error)]
pub enum VertexAiClientError {
/// No project was set on the builder and `GOOGLE_CLOUD_PROJECT` did not supply one.
#[error(
"Google Cloud project is required. Set it via `ClientBuilder::with_project()` or `GOOGLE_CLOUD_PROJECT`"
)]
MissingProject,
/// Building the default (source) credentials failed; carries the builder's error text.
#[error("failed to build source credentials: {0}")]
SourceCredentials(String),
/// Building impersonated credentials for the principal named by
/// `GOOGLE_CLOUD_SERVICE_ACCOUNT` failed; carries the builder's error text.
#[error("failed to build impersonated credentials: {0}")]
ImpersonatedCredentials(String),
/// Building the Vertex AI `PredictionService` failed; carries the builder's error text.
#[error("failed to build Vertex AI prediction service: {0}")]
PredictionService(String),
/// Returned by `ProviderClient::from_val`, which this client does not support.
#[error(
"Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials or `Client::builder().with_credentials(...).build()` for explicit credentials."
)]
InvalidInput,
}
/// Resolves the credentials a [`Client`] will authenticate with.
///
/// Resolution order:
/// 1. `explicit_creds`, when provided, is returned untouched.
/// 2. Otherwise default credentials are built with `credentials::Builder::default()`.
/// 3. If the `GOOGLE_CLOUD_SERVICE_ACCOUNT` environment variable can be read, the
///    default credentials become the source for impersonating that principal;
///    if it cannot be read, the default credentials are returned as they are.
///
/// # Errors
///
/// * [`VertexAiClientError::SourceCredentials`] when the default credentials
///   cannot be built.
/// * [`VertexAiClientError::ImpersonatedCredentials`] when the impersonated
///   credentials cannot be built.
fn build_credentials(
    explicit_creds: Option<Credentials>,
) -> Result<Credentials, VertexAiClientError> {
    // Caller-supplied credentials short-circuit every environment lookup.
    if let Some(provided) = explicit_creds {
        return Ok(provided);
    }

    let source = credentials::Builder::default()
        .build()
        .map_err(|err| VertexAiClientError::SourceCredentials(err.to_string()))?;

    match std::env::var("GOOGLE_CLOUD_SERVICE_ACCOUNT") {
        Ok(target_principal) => {
            credentials::impersonated::Builder::from_source_credentials(source)
                .with_target_principal(target_principal)
                .build()
                .map_err(|err| VertexAiClientError::ImpersonatedCredentials(err.to_string()))
        }
        // Any failure to read the variable means "no impersonation requested".
        Err(_) => Ok(source),
    }
}
/// Builder for a Vertex AI `Client`.
///
/// Every field is optional; `ClientBuilder::build` fills the gaps from the
/// environment as described on each field.
#[derive(Clone, Debug)]
pub struct ClientBuilder {
// Google Cloud project; `build` falls back to `GOOGLE_CLOUD_PROJECT` when unset.
project: Option<String>,
// Location; `build` falls back to `GOOGLE_CLOUD_LOCATION`, then `DEFAULT_LOCATION`.
location: Option<String>,
// Explicit credentials; when unset, `build_credentials` constructs default ones.
credentials: Option<Credentials>,
}
impl ClientBuilder {
    /// Creates a builder with no project, location, or credentials configured.
    pub fn new() -> Self {
        Self {
            project: None,
            location: None,
            credentials: None,
        }
    }

    /// Sets the Google Cloud project, taking precedence over `GOOGLE_CLOUD_PROJECT`.
    pub fn with_project(mut self, project: &str) -> Self {
        self.project = Some(project.to_string());
        self
    }

    /// Sets the location, taking precedence over `GOOGLE_CLOUD_LOCATION` and
    /// [`DEFAULT_LOCATION`].
    pub fn with_location(mut self, location: &str) -> Self {
        self.location = Some(location.to_string());
        self
    }

    /// Supplies explicit credentials.
    ///
    /// When set, `build` uses them as-is and skips both default credential
    /// construction and the `GOOGLE_CLOUD_SERVICE_ACCOUNT` impersonation step.
    pub fn with_credentials(mut self, credentials: Credentials) -> Self {
        self.credentials = Some(credentials);
        self
    }

    /// Resolves the configuration and produces a [`Client`].
    ///
    /// * Project: builder value, else `GOOGLE_CLOUD_PROJECT`.
    /// * Location: builder value, else `GOOGLE_CLOUD_LOCATION`, else
    ///   [`DEFAULT_LOCATION`].
    ///
    /// Empty or whitespace-only values from either source are treated as unset,
    /// so a blank `GOOGLE_CLOUD_PROJECT` is reported here as a missing project
    /// instead of producing a client with an empty project.
    ///
    /// Credentials are resolved eagerly; the prediction service itself is only
    /// built on the first call to `Client::get_inner`.
    ///
    /// # Errors
    ///
    /// * [`VertexAiClientError::MissingProject`] when no non-blank project is
    ///   available.
    /// * Any error returned by `build_credentials`.
    pub fn build(self) -> Result<Client, VertexAiClientError> {
        let project = Self::non_blank(self.project)
            .or_else(|| Self::non_blank(std::env::var("GOOGLE_CLOUD_PROJECT").ok()))
            .ok_or(VertexAiClientError::MissingProject)?;
        let location = Self::non_blank(self.location)
            .or_else(|| Self::non_blank(std::env::var("GOOGLE_CLOUD_LOCATION").ok()))
            .unwrap_or_else(|| DEFAULT_LOCATION.to_string());
        let credentials = build_credentials(self.credentials)?;

        Ok(Client {
            project,
            location,
            credentials,
            vertex_client: Arc::new(OnceCell::new()),
        })
    }

    /// Discards a setting that is empty or whitespace-only; other values are
    /// returned unmodified (no trimming is applied to the value that is kept).
    fn non_blank(value: Option<String>) -> Option<String> {
        value.filter(|candidate| !candidate.trim().is_empty())
    }
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// Vertex AI client holding the resolved project, location, and credentials.
///
/// The underlying `PredictionService` is not built at construction time; it is
/// created on first use by `Client::get_inner` and stored in `vertex_client`.
#[derive(Clone, Debug)]
pub struct Client {
// Google Cloud project this client was built for.
project: String,
// Location this client was built for.
location: String,
// Credentials handed to the prediction service builder.
credentials: Credentials,
// Lazily initialized prediction service (or the error from building it).
// Held behind an `Arc`, so clones of this `Client` share the same cell.
pub(crate) vertex_client:
Arc<OnceCell<Result<vertexai::client::PredictionService, VertexAiClientError>>>,
}
impl Client {
    /// Returns an empty [`ClientBuilder`].
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Builds a client from an empty builder, so the project, location, and
    /// credentials all come from the environment or defaults.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ClientBuilder::build`.
    pub fn new() -> Result<Self, VertexAiClientError> {
        ClientBuilder::new().build()
    }

    /// Inherent shortcut for `<Client as ProviderClient>::from_env()`, callable
    /// without importing the trait.
    ///
    /// # Errors
    ///
    /// Propagates any error from the trait implementation.
    pub fn from_env() -> Result<Self, VertexAiClientError> {
        <Self as ProviderClient>::from_env()
    }

    /// Google Cloud project this client was built for.
    pub fn project(&self) -> &str {
        &self.project
    }

    /// Location this client was built for.
    pub fn location(&self) -> &str {
        &self.location
    }

    /// Returns the prediction service, building it on the first call.
    ///
    /// The outcome of that first build is stored in the shared cell, whether it
    /// succeeded or failed. Later calls (including calls on clones of this
    /// client) return a reference to the stored service, or a clone of the
    /// stored error, without rebuilding.
    ///
    /// # Errors
    ///
    /// [`VertexAiClientError::PredictionService`] when the service could not be
    /// built.
    pub async fn get_inner(
        &self,
    ) -> Result<&vertexai::client::PredictionService, VertexAiClientError> {
        self.vertex_client
            .get_or_init(|| async {
                // Clone the credentials only here, on the initialization path;
                // calls that find the cell already populated never run this.
                vertexai::client::PredictionService::builder()
                    .with_credentials(self.credentials.clone())
                    .build()
                    .await
                    .map_err(|error| VertexAiClientError::PredictionService(error.to_string()))
            })
            .await
            .as_ref()
            .map_err(Clone::clone)
    }
}
impl ProviderClient for Client {
type Input = Nothing;
type Error = VertexAiClientError;
fn from_env() -> Result<Self, Self::Error>
where
Self: Sized,
{
Client::new()
}
fn from_val(_: Self::Input) -> Result<Self, Self::Error>
where
Self: Sized,
{
Err(VertexAiClientError::InvalidInput)
}
}
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Creates a completion model handle for `model`, backed by a clone of this
    /// client.
    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
        let client = self.clone();
        let model_name = model.into();
        CompletionModel::new(client, model_name)
    }
}
impl VerifyClient for Client {
/// Always reports success: the body returns `Ok(())` without building the
/// prediction service or issuing any request.
// NOTE(review): no credential or connectivity check happens here — confirm
// whether that is intentional for this provider.
async fn verify(&self) -> Result<(), VerifyError> {
Ok(())
}
}