use crate::completion::CompletionModel;
use google_cloud_aiplatform_v1 as vertexai;
use google_cloud_auth::credentials;
use google_cloud_auth::credentials::Credentials;
use rig::client::{CompletionClient, Nothing};
use rig::prelude::*;
use std::sync::Arc;
use tokio::sync::OnceCell;
/// Vertex AI location used by [`ClientBuilder::build`] when no location was set on the
/// builder and the `GOOGLE_CLOUD_LOCATION` environment variable cannot be read.
pub const DEFAULT_LOCATION: &str = "global";
/// Resolves the credentials a [`Client`] will authenticate with.
///
/// Resolution order:
/// 1. `explicit_creds`, when provided, is returned untouched.
/// 2. Otherwise source credentials are built with `credentials::Builder::default()`.
/// 3. If the `GOOGLE_CLOUD_SERVICE_ACCOUNT` environment variable can be read, the source
///    credentials are wrapped in impersonated credentials targeting that principal.
///
/// # Errors
///
/// Returns a human-readable message when either the source credentials or the
/// impersonated credentials fail to build.
fn build_credentials(explicit_creds: Option<Credentials>) -> Result<Credentials, String> {
    // Caller-supplied credentials short-circuit every environment lookup.
    if let Some(provided) = explicit_creds {
        return Ok(provided);
    }

    let base = credentials::Builder::default()
        .build()
        .map_err(|e| format!("Failed to build source credentials: {e}"))?;

    match std::env::var("GOOGLE_CLOUD_SERVICE_ACCOUNT") {
        // Impersonation is opt-in: only attempted when the variable is readable.
        Ok(principal) => credentials::impersonated::Builder::from_source_credentials(base)
            .with_target_principal(principal)
            .build()
            .map_err(|e| format!("Failed to build impersonated credentials: {e}")),
        Err(_) => Ok(base),
    }
}
/// Builder for [`Client`].
///
/// Every field is optional; fallbacks for unset fields are applied in
/// [`ClientBuilder::build`].
#[derive(Clone, Debug)]
pub struct ClientBuilder {
/// Explicit Google Cloud project; when `None`, `build` reads `GOOGLE_CLOUD_PROJECT`.
project: Option<String>,
/// Explicit location; when `None`, `build` reads `GOOGLE_CLOUD_LOCATION` and then
/// falls back to `DEFAULT_LOCATION`.
location: Option<String>,
/// Explicit credentials; when `None`, `build` resolves them via `build_credentials`.
credentials: Option<Credentials>,
}
impl ClientBuilder {
    /// Creates a builder with no project, location or credentials set.
    pub fn new() -> Self {
        Self {
            project: None,
            location: None,
            credentials: None,
        }
    }

    /// Sets the Google Cloud project explicitly, taking precedence over
    /// the `GOOGLE_CLOUD_PROJECT` environment variable.
    pub fn with_project(self, project: &str) -> Self {
        Self {
            project: Some(project.to_owned()),
            ..self
        }
    }

    /// Sets the location explicitly, taking precedence over the
    /// `GOOGLE_CLOUD_LOCATION` environment variable and `DEFAULT_LOCATION`.
    pub fn with_location(self, location: &str) -> Self {
        Self {
            location: Some(location.to_owned()),
            ..self
        }
    }

    /// Supplies ready-made credentials, bypassing the default credential lookup.
    pub fn with_credentials(self, credentials: Credentials) -> Self {
        Self {
            credentials: Some(credentials),
            ..self
        }
    }

    /// Consumes the builder and produces a [`Client`].
    ///
    /// The project is resolved first, then the location, then the credentials.
    /// The underlying prediction service is not created here; the client starts
    /// with an empty `OnceCell`.
    ///
    /// # Errors
    ///
    /// Returns an error message when no project was set and `GOOGLE_CLOUD_PROJECT`
    /// cannot be read, or when `build_credentials` fails.
    pub fn build(self) -> Result<Client, String> {
        let Self {
            project,
            location,
            credentials,
        } = self;

        let project = match project {
            Some(explicit) => explicit,
            None => std::env::var("GOOGLE_CLOUD_PROJECT").map_err(|_| {
                "Google Cloud project is required. Set it via ClientBuilder::with_project() or GOOGLE_CLOUD_PROJECT environment variable".to_string()
            })?,
        };

        let location = match location {
            Some(explicit) => explicit,
            None => std::env::var("GOOGLE_CLOUD_LOCATION")
                .unwrap_or_else(|_| DEFAULT_LOCATION.to_string()),
        };

        let credentials = build_credentials(credentials)?;

        Ok(Client {
            project,
            location,
            credentials,
            vertex_client: Arc::new(OnceCell::new()),
        })
    }
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// Vertex AI client holding the resolved project, location and credentials, plus a
/// lazily created prediction service handle.
#[derive(Clone, Debug)]
pub struct Client {
/// Google Cloud project resolved by `ClientBuilder::build`.
project: String,
/// Location resolved by `ClientBuilder::build`.
location: String,
/// Credentials passed to the prediction service builder in `get_inner`.
credentials: Credentials,
/// Prediction service handle, initialised on the first `get_inner` call.
/// Held behind an `Arc`, so clones of this `Client` share the same cell.
pub(crate) vertex_client: Arc<OnceCell<vertexai::client::PredictionService>>,
}
impl Client {
    /// Returns a fresh [`ClientBuilder`] for configuring project, location and credentials.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Builds a client from an unconfigured [`ClientBuilder`], i.e. entirely from the
    /// environment and the default credential lookup.
    ///
    /// # Panics
    ///
    /// Panics when [`ClientBuilder::build`] fails, for example when
    /// `GOOGLE_CLOUD_PROJECT` is not set or credentials cannot be built. Use
    /// [`Client::builder`] and handle the `Result` to avoid the panic.
    pub fn new() -> Self {
        ClientBuilder::new()
            .build()
            .expect("Failed to build Vertex AI client. Make sure GOOGLE_CLOUD_PROJECT is set and credentials are configured (e.g., via 'gcloud auth application-default login')")
    }

    /// Builds a client by delegating to `ProviderClient::from_env`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as the delegated call.
    pub fn from_env() -> Self {
        <Self as ProviderClient>::from_env()
    }

    /// Google Cloud project this client was built with.
    pub fn project(&self) -> &str {
        &self.project
    }

    /// Location this client was built with.
    pub fn location(&self) -> &str {
        &self.location
    }

    /// Returns the underlying prediction service, creating it on first use.
    ///
    /// The service is stored in the shared `OnceCell`, so later calls (including
    /// calls on clones of this client) reuse the same instance.
    ///
    /// # Panics
    ///
    /// Panics if the prediction service fails to build during initialisation.
    pub async fn get_inner(&self) -> &vertexai::client::PredictionService {
        self.vertex_client
            .get_or_init(|| async {
                // Clone the credentials only inside the initialiser, so calls made
                // after the cell is populated do not pay for a clone they never use.
                vertexai::client::PredictionService::builder()
                    .with_credentials(self.credentials.clone())
                    .build()
                    .await
                    .expect("Failed to build Vertex AI client. Make sure you have Google Cloud credentials configured (e.g., via 'gcloud auth application-default login')")
            })
            .await
    }
}
impl Default for Client {
fn default() -> Self {
Client::new()
}
}
impl ProviderClient for Client {
    /// No value-based construction is supported, so the input carries nothing.
    type Input = Nothing;

    /// Builds a client via [`Client::new`].
    ///
    /// # Panics
    ///
    /// Panics when [`Client::new`] panics, i.e. when the client cannot be built.
    fn from_env() -> Self
    where
        Self: Sized,
    {
        Client::new()
    }

    /// Not supported for this provider.
    ///
    /// # Panics
    ///
    /// Always panics, pointing the caller at the supported constructors.
    fn from_val(_: Self::Input) -> Self
    where
        Self: Sized,
    {
        // `with_credentials` lives on `ClientBuilder`, reached through
        // `Client::builder()`; `Client::new()` already returns a finished `Client`.
        panic!(
            "Vertex AI uses Application Default Credentials (ADC). Use `Client::from_env()` for default credentials, or `Client::builder().with_credentials(...).build()` for custom credentials."
        );
    }
}
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Creates a completion model bound to a clone of this client and the given
    /// model name.
    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
        let client = self.clone();
        let model_name: String = model.into();
        CompletionModel::new(client, model_name)
    }
}
impl VerifyClient for Client {
/// Always returns `Ok(())`: no request is sent and the credentials are not
/// checked here, so a successful result says nothing about connectivity.
async fn verify(&self) -> Result<(), VerifyError> {
Ok(())
}
}